diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 37dfec9..520f6ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,4 +7,14 @@ repos: - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: - - id: flake8 \ No newline at end of file + - id: flake8 +- repo: local + hooks: + - id: pdoc + name: pdoc + description: 'pdoc3: Auto-generate API documentation for Python projects' + entry: pdoc --html --skip-errors --force -o docs/api carvekit + language: python + language_version: python3 + require_serial: true + types: [python] diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 975093f..7f537db 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -36,12 +36,15 @@ ENV CARVEKIT_PORT '5000' ENV CARVEKIT_HOST '0.0.0.0' ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7' ENV CARVEKIT_PREPROCESSING_METHOD 'none' -ENV CARVEKIT_POSTPROCESSING_METHOD 'fba' +ENV CARVEKIT_POSTPROCESSING_METHOD 'cascade_fba' ENV CARVEKIT_DEVICE 'cpu' +ENV CARVEKIT_BATCH_SIZE_PRE=5 ENV CARVEKIT_BATCH_SIZE_SEG '5' ENV CARVEKIT_BATCH_SIZE_MATTING '1' +ENV CARVEKIT_BATCH_SIZE_REFINE '1' ENV CARVEKIT_SEG_MASK_SIZE '640' ENV CARVEKIT_MATTING_MASK_SIZE '2048' +ENV CARVEKIT_REFINE_MASK_SIZE '900' ENV CARVEKIT_AUTH_ENABLE '1' ENV CARVEKIT_FP16 '0' ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231 diff --git a/Dockerfile.cuda b/Dockerfile.cuda index b5d31df..1155b0c 100644 --- a/Dockerfile.cuda +++ b/Dockerfile.cuda @@ -36,12 +36,15 @@ ENV CARVEKIT_PORT '5000' ENV CARVEKIT_HOST '0.0.0.0' ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7' ENV CARVEKIT_PREPROCESSING_METHOD 'none' -ENV CARVEKIT_POSTPROCESSING_METHOD 'fba' +ENV CARVEKIT_POSTPROCESSING_METHOD 'cascade_fba' ENV CARVEKIT_DEVICE 'cuda' +ENV CARVEKIT_BATCH_SIZE_PRE=5 ENV CARVEKIT_BATCH_SIZE_SEG '5' ENV CARVEKIT_BATCH_SIZE_MATTING '1' +ENV CARVEKIT_BATCH_SIZE_REFINE '1' ENV CARVEKIT_SEG_MASK_SIZE '640' ENV CARVEKIT_MATTING_MASK_SIZE '2048' +ENV CARVEKIT_REFINE_MASK_SIZE '900' ENV CARVEKIT_AUTH_ENABLE '1' ENV CARVEKIT_FP16 '0' ENV 
CARVEKIT_TRIMAP_PROB_THRESHOLD=231 diff --git a/README.md b/README.md index 8f44cc7..cb20f72 100644 --- a/README.md +++ b/README.md @@ -26,13 +26,16 @@ Automated high-quality background removal framework for an image using neural ne ## πŸŽ† Features: - High Quality +- Works offline - Batch Processing - NVIDIA CUDA and CPU processing - FP16 inference: Fast inference with low memory usage - Easy inference - 100% remove.bg compatible FastAPI HTTP API - Removes background from hairs +- Automatic best method selection for user's image - Easy integration with your code +- Models hosted on [HuggingFace](https://huggingface.co/Carve) ## β›± Try yourself on [Google Colab](https://colab.research.google.com/github/OPHoperHPO/image-background-remove-tool/blob/master/docs/other/carvekit_try.ipynb) ## ⛓️ How does it work? @@ -64,10 +67,17 @@ It can be briefly described as ## πŸ–ΌοΈ Image pre-processing and post-processing methods: ### πŸ” Preprocessing methods: * `none` - No preprocessing methods used. -> They will be added in the future. +* [`autoscene`](https://huggingface.co/Carve/scene_classifier/) - Automatically detects the scene type using classifier and applies the appropriate model. (default) +* `auto` - Performs in-depth image analysis and more accurately determines the best background removal method. Uses object classifier and scene classifier together. +> ### Notes: +> 1. `AutoScene` and `auto` may override the model and parameters specified by the user without logging. +> So, if you want to use a specific model, make all constant etc., you should disable auto preprocessing methods first! +> 2. At the moment for `auto` method universal models are selected for some specific domains, since the added models are currently not enough for so many types of scenes. +> In the future, when some variety of models is added, auto-selection will be rewritten for the better. ### βœ‚ Post-processing methods: * `none` - No post-processing methods used. 
-* `fba` (default) - This algorithm improves the borders of the image when removing the background from images with hair, etc. using FBA Matting neural network. This method gives the best result in combination with u2net without any preprocessing methods. +* `fba` - This algorithm improves the borders of the image when removing the background from images with hair, etc. using FBA Matting neural network. +* `cascade_fba` (default) - This algorithm refines the segmentation mask using CascadePSP neural network and then applies the FBA algorithm. ## 🏷 Setup for CPU processing: 1. `pip install carvekit --extra-index-url https://download.pytorch.org/whl/cpu` @@ -84,12 +94,15 @@ import torch from carvekit.api.high import HiInterface # Check doc strings for more information -interface = HiInterface(object_type="hairs-like", # Can be "object" or "hairs-like". +interface = HiInterface(object_type="auto", # Can be "object" or "hairs-like" or "auto" batch_size_seg=5, + batch_size_pre=5, batch_size_matting=1, + batch_size_refine=1, device='cuda' if torch.cuda.is_available() else 'cpu', seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net matting_mask_size=2048, + refine_mask_size=900, trimap_prob_threshold=231, trimap_dilation=30, trimap_erosion_iters=5, @@ -100,33 +113,65 @@ cat_wo_bg.save('2.png') ``` - +### Analogue of `auto` preprocessing method from cli +``` python +from carvekit.api.autointerface import AutoInterface +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4 + +scene_classifier = SceneClassifier(device="cpu", batch_size=1) +object_classifier = SimplifiedYoloV4(device="cpu", batch_size=1) + +interface = AutoInterface(scene_classifier=scene_classifier, + object_classifier=object_classifier, + segmentation_batch_size=1, + postprocessing_batch_size=1, + postprocessing_image_size=2048, + refining_batch_size=1, + refining_image_size=900, + segmentation_device="cpu", + fp16=False, + 
postprocessing_device="cpu") +images_without_background = interface(['./tests/data/cat.jpg']) +cat_wo_bg = images_without_background[0] +cat_wo_bg.save('2.png') +``` ### If you want control everything ``` python import PIL.Image from carvekit.api.interface import Interface from carvekit.ml.wrap.fba_matting import FBAMatting +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.ml.wrap.cascadepsp import CascadePSP from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 -from carvekit.pipelines.postprocessing import MattingMethod -from carvekit.pipelines.preprocessing import PreprocessingStub +from carvekit.pipelines.postprocessing import CasMattingMethod +from carvekit.pipelines.preprocessing import AutoScene from carvekit.trimap.generator import TrimapGenerator # Check doc strings for more information seg_net = TracerUniversalB7(device='cpu', - batch_size=1) - + batch_size=1, fp16=False) +cascade_psp = CascadePSP(device='cpu', + batch_size=1, + input_tensor_size=900, + fp16=False, + processing_accelerate_image_size=2048, + global_step_only=False) fba = FBAMatting(device='cpu', input_tensor_size=2048, - batch_size=1) + batch_size=1, fp16=False) -trimap = TrimapGenerator() +trimap = TrimapGenerator(prob_threshold=231, kernel_size=30, erosion_iters=5) -preprocessing = PreprocessingStub() +scene_classifier = SceneClassifier(device='cpu', batch_size=5) +preprocessing = AutoScene(scene_classifier=scene_classifier) -postprocessing = MattingMethod(matting_module=fba, - trimap_generator=trimap, - device='cpu') +postprocessing = CasMattingMethod( + refining_module=cascade_psp, + matting_module=fba, + trimap_generator=trimap, + device='cpu') interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, @@ -134,8 +179,7 @@ interface = Interface(pre_pipe=preprocessing, image = PIL.Image.open('tests/data/cat.jpg') cat_wo_bg = interface([image])[0] -cat_wo_bg.save('2.png') - +cat_wo_bg.save('2.png') ``` @@ -151,24 +195,35 @@ Usage: carvekit 
[OPTIONS] Options: -i ./2.jpg Path to input file or dir [required] -o ./2.png Path to output file or dir - --pre none Preprocessing method - --post fba Postprocessing method. + --pre autoscene Preprocessing method + --post cascade_fba Postprocessing method. --net tracer_b7 Segmentation Network. Check README for more info. + --recursive Enables recursive search for images in a folder --batch_size 10 Batch Size for list of images to be loaded to RAM - + + --batch_size_pre 5 Batch size for list of images to be + processed by preprocessing method network + --batch_size_seg 5 Batch size for list of images to be processed by segmentation network --batch_size_mat 1 Batch size for list of images to be processed by matting network + --batch_size_refine 1 Batch size for list of images to be + processed by refining network + --seg_mask_size 640 The size of the input image for the segmentation neural network. Use 640 for Tracer B7 and 320 for U2Net --matting_mask_size 2048 The size of the input image for the matting neural network. + + --refine_mask_size 900 The size of the input image for the refining + neural network. 
+ --trimap_dilation 30 The size of the offset radius from the object mask in pixels when forming an unknown area diff --git a/carvekit/__init__.py b/carvekit/__init__.py index b58821b..03a3882 100644 --- a/carvekit/__init__.py +++ b/carvekit/__init__.py @@ -1 +1 @@ -version = "4.1.0" +version = "4.5.0" diff --git a/carvekit/__main__.py b/carvekit/__main__.py index acf901d..c40bc61 100644 --- a/carvekit/__main__.py +++ b/carvekit/__main__.py @@ -16,8 +16,8 @@ ) @click.option("-i", required=True, type=str, help="Path to input file or dir") @click.option("-o", default="none", type=str, help="Path to output file or dir") -@click.option("--pre", default="none", type=str, help="Preprocessing method") -@click.option("--post", default="fba", type=str, help="Postprocessing method.") +@click.option("--pre", default="autoscene", type=str, help="Preprocessing method") +@click.option("--post", default="cascade_fba", type=str, help="Postprocessing method.") @click.option("--net", default="tracer_b7", type=str, help="Segmentation Network") @click.option( "--recursive", @@ -31,6 +31,12 @@ type=int, help="Batch Size for list of images to be loaded to RAM", ) +@click.option( + "--batch_size_pre", + default=5, + type=int, + help="Batch size for list of images to be processed by preprocessing method network", +) @click.option( "--batch_size_seg", default=5, @@ -43,6 +49,12 @@ type=int, help="Batch size for list of images to be processed by matting " "network", ) +@click.option( + "--batch_size_refine", + default=1, + type=int, + help="Batch size for list of images to be processed by refining network", +) @click.option( "--seg_mask_size", default=640, @@ -55,6 +67,12 @@ type=int, help="The size of the input image for the matting neural network.", ) +@click.option( + "--refine_mask_size", + default=900, + type=int, + help="The size of the input image for the refining neural network.", +) @click.option( "--trimap_dilation", default=30, @@ -89,10 +107,13 @@ def removebg( net: str, 
recursive: bool, batch_size: int, + batch_size_pre: int, batch_size_seg: int, batch_size_mat: int, + batch_size_refine: int, seg_mask_size: int, matting_mask_size: int, + refine_mask_size: int, device: str, fp16: bool, trimap_dilation: int, @@ -121,12 +142,15 @@ def removebg( device=device, batch_size_seg=batch_size_seg, batch_size_matting=batch_size_mat, + batch_size_refine=batch_size_refine, seg_mask_size=seg_mask_size, matting_mask_size=matting_mask_size, + refine_mask_size=refine_mask_size, fp16=fp16, trimap_dilation=trimap_dilation, trimap_erosion=trimap_erosion, trimap_prob_threshold=trimap_prob_threshold, + batch_size_pre=batch_size_pre, ) interface = init_interface(interface_config) diff --git a/carvekit/api/autointerface.py b/carvekit/api/autointerface.py new file mode 100644 index 0000000..60ba0c7 --- /dev/null +++ b/carvekit/api/autointerface.py @@ -0,0 +1,252 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+License: Apache License 2.0 +""" +from collections import Counter +from pathlib import Path + +from PIL import Image +from typing import Union, List, Dict + +from carvekit.api.interface import Interface +from carvekit.ml.wrap.basnet import BASNET +from carvekit.ml.wrap.cascadepsp import CascadePSP +from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 +from carvekit.ml.wrap.fba_matting import FBAMatting +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 +from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4 +from carvekit.pipelines.postprocessing import CasMattingMethod +from carvekit.trimap.generator import TrimapGenerator + +__all__ = ["AutoInterface"] + +from carvekit.utils.image_utils import load_image + +from carvekit.utils.pool_utils import thread_pool_processing + + +class AutoInterface(Interface): + def __init__( + self, + scene_classifier: SceneClassifier, + object_classifier: SimplifiedYoloV4, + segmentation_batch_size: int = 3, + refining_batch_size: int = 1, + refining_image_size: int = 900, + postprocessing_batch_size: int = 1, + postprocessing_image_size: int = 2048, + segmentation_device: str = "cpu", + postprocessing_device: str = "cpu", + fp16=False, + ): + """ + Args: + scene_classifier: SceneClassifier instance + object_classifier: YoloV4_COCO instance + """ + self.scene_classifier = scene_classifier + self.object_classifier = object_classifier + self.segmentation_batch_size = segmentation_batch_size + self.refining_batch_size = refining_batch_size + self.refining_image_size = refining_image_size + self.postprocessing_batch_size = postprocessing_batch_size + self.postprocessing_image_size = postprocessing_image_size + self.segmentation_device = segmentation_device + self.postprocessing_device = postprocessing_device + self.fp16 = fp16 + super().__init__( + seg_pipe=None, post_pipe=None, pre_pipe=None + ) # just for compatibility with 
Interface class + + @staticmethod + def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]): + """ + Selects the parameters for the network depending on the scene + + Args: + net: network + """ + if net == TracerUniversalB7: + return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5} + elif net == U2NET: + return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5} + elif net == DeepLabV3: + return {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20} + elif net == BASNET: + return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5} + else: + raise ValueError("Unknown network type") + + def select_net(self, scene: str, images_info: List[dict]): + # TODO: Update this function, when new networks will be added + if scene == "hard": + for image_info in images_info: + objects = image_info["objects"] + if len(objects) == 0: + image_info[ + "net" + ] = TracerUniversalB7 # It seems that the image is empty, but we will try to process it + continue + obj_counter: Dict = dict(Counter([obj for obj in objects])) + # fill empty classes + for _tag in self.object_classifier.db: + if _tag not in obj_counter: + obj_counter[_tag] = 0 + + non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0] + + if obj_counter["human"] > 0 and len(non_empty_classes) == 1: + # Human only case. Hard Scene? It may be a photo of a person in far/middle distance. + image_info["net"] = TracerUniversalB7 + # TODO: will use DeepLabV3+ for this image, it is more suitable for this case, + # but needs checks for small bbox + elif obj_counter["human"] > 0 and len(non_empty_classes) > 1: + # Okay, we have a human without extra hairs and something else. 
Hard border + image_info["net"] = TracerUniversalB7 + elif obj_counter["cars"] > 0: + # Cars case + image_info["net"] = TracerUniversalB7 + elif obj_counter["animals"] > 0: + # Animals case + image_info["net"] = U2NET # animals should be always in soft scenes + else: + # We have no idea what is in the image, so we will try to process it with universal model + image_info["net"] = TracerUniversalB7 + + elif scene == "soft": + for image_info in images_info: + objects = image_info["objects"] + if len(objects) == 0: + image_info[ + "net" + ] = TracerUniversalB7 # It seems that the image is empty, but we will try to process it + continue + obj_counter: Dict = dict(Counter([obj for obj in objects])) + # fill empty classes + for _tag in self.object_classifier.db: + if _tag not in obj_counter: + obj_counter[_tag] = 0 + + non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0] + + if obj_counter["human"] > 0 and len(non_empty_classes) == 1: + # Human only case. It may be a portrait + image_info["net"] = U2NET + elif obj_counter["human"] > 0 and len(non_empty_classes) > 1: + # Okay, we have a human with hairs and something else + image_info["net"] = U2NET + elif obj_counter["cars"] > 0: + # Cars case. 
+ image_info["net"] = TracerUniversalB7 + elif obj_counter["animals"] > 0: + # Animals case + image_info["net"] = U2NET # animals should be always in soft scenes + else: + # We have no idea what is in the image, so we will try to process it with universal model + image_info["net"] = TracerUniversalB7 + elif scene == "digital": + for image_info in images_info: # TODO: not implemented yet + image_info[ + "net" + ] = TracerUniversalB7 # No dedicated model for digital scenes yet, so the universal model is used + + def __call__(self, images: List[Union[str, Path, Image.Image]]): + """ + Automatically detects the scene and selects the appropriate network for segmentation + + Args: + images: list of images to process, given as file paths (str or pathlib.Path) + or as PIL.Image.Image instances + + Returns: + list of images with the background removed, in the same order as the input + """ + loaded_images = thread_pool_processing(load_image, images) + + scene_analysis = self.scene_classifier(loaded_images) + images_objects = self.object_classifier(loaded_images) + + images_per_scene = {} + for i, image in enumerate(loaded_images): + scene_name = scene_analysis[i][0][0] + if scene_name not in images_per_scene: + images_per_scene[scene_name] = [] + images_per_scene[scene_name].append( + {"image": image, "objects": images_objects[i]} + ) + + for scene_name, images_info in list(images_per_scene.items()): + self.select_net(scene_name, images_info) + + # groups images by net + for scene_name, images_info in list(images_per_scene.items()): + groups = {} + for image_info in images_info: + net = image_info["net"] + if net not in groups: + groups[net] = [] + groups[net].append(image_info) + for net, gimages_info in list(groups.items()): + sc_images = [image_info["image"] for image_info in gimages_info] + masks = net( + device=self.segmentation_device, + batch_size=self.segmentation_batch_size, + fp16=self.fp16, + )(sc_images) + + for i, image_info in enumerate(gimages_info): + image_info["mask"] = masks[i] + + cascadepsp = CascadePSP( + device=self.postprocessing_device, + fp16=self.fp16, + 
input_tensor_size=self.refining_image_size, + batch_size=self.refining_batch_size, + ) + + fba = FBAMatting( + device=self.postprocessing_device, + batch_size=self.postprocessing_batch_size, + input_tensor_size=self.postprocessing_image_size, + fp16=self.fp16, + ) + # groups images by net + for scene_name, images_info in list(images_per_scene.items()): + groups = {} + for image_info in images_info: + net = image_info["net"] + if net not in groups: + groups[net] = [] + groups[net].append(image_info) + for net, gimages_info in list(groups.items()): + sc_images = [image_info["image"] for image_info in gimages_info] + # noinspection PyArgumentList + trimap_generator = TrimapGenerator(**self.select_params_for_net(net)) + matting_method = CasMattingMethod( + refining_module=cascadepsp, + matting_module=fba, + trimap_generator=trimap_generator, + device=self.postprocessing_device, + ) + masks = [image_info["mask"] for image_info in gimages_info] + result = matting_method(sc_images, masks) + + for i, image_info in enumerate(gimages_info): + image_info["result"] = result[i] + + # Reconstructing the original order of image + result = [] + for image in loaded_images: + for scene_name, images_info in list(images_per_scene.items()): + for image_info in images_info: + if image_info["image"] == image: + result.append(image_info["result"]) + break + if len(result) != len(images): + raise RuntimeError( + "Something went wrong with restoring original order. Please report this bug." + ) + return result diff --git a/carvekit/api/high.py b/carvekit/api/high.py index 46fb9d3..dda60f3 100644 --- a/carvekit/api/high.py +++ b/carvekit/api/high.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+ License: Apache License 2.0 """ import warnings @@ -8,20 +10,26 @@ from carvekit.api.interface import Interface from carvekit.ml.wrap.fba_matting import FBAMatting from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 +from carvekit.ml.wrap.cascadepsp import CascadePSP +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.pipelines.preprocessing import AutoScene from carvekit.ml.wrap.u2net import U2NET -from carvekit.pipelines.postprocessing import MattingMethod +from carvekit.pipelines.postprocessing import CasMattingMethod from carvekit.trimap.generator import TrimapGenerator class HiInterface(Interface): def __init__( self, - object_type: str = "object", + object_type: str = "auto", + batch_size_pre=5, batch_size_seg=2, batch_size_matting=1, + batch_size_refine=1, device="cpu", seg_mask_size=640, matting_mask_size=2048, + refine_mask_size=900, trimap_prob_threshold=231, trimap_dilation=30, trimap_erosion_iters=5, @@ -31,69 +39,96 @@ def __init__( Initializes High Level interface. Args: - object_type: Interest object type. Can be "object" or "hairs-like". - matting_mask_size: The size of the input image for the matting neural network. - seg_mask_size: The size of the input image for the segmentation neural network. - batch_size_seg: Number of images processed per one segmentation neural network call. - batch_size_matting: Number of images processed per one matting neural network call. - device: Processing device - fp16: Use half precision. Reduce memory usage and increase speed. Experimental support - trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied - trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area - trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area + object_type (str, default=auto): Interest object type. 
Can be "object", "hairs-like" or "auto". + matting_mask_size (int, default=2048): The size of the input image for the matting neural network. + seg_mask_size (int, default=640): The size of the input image for the segmentation neural network. + batch_size_pre (int, default=5): Number of images processed per one preprocessing method call. + batch_size_seg (int, default=2): Number of images processed per one segmentation neural network call. + batch_size_matting (int, default=1): Number of images processed per one matting neural network call. + device (Literal[cpu, cuda], default=cpu): Processing device + fp16 (bool, default=False): Use half precision. Reduce memory usage and increase speed. + .. CAUTION:: ⚠️ **Experimental support** + trimap_prob_threshold (int, default=231): Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied + trimap_dilation (int, default=30): The size of the offset radius from the object mask in pixels when forming an unknown area + trimap_erosion_iters (int, default=5): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area + refine_mask_size (int, default=900): The size of the input image for the refinement neural network. + batch_size_refine (int, default=1): Number of images processed per one refinement neural network call. - Notes: - 1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also - result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in - range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and + + .. NOTE:: + 1. Changing seg_mask_size may cause an `out-of-memory` error if the value is too large, and it may also + result in reduced precision. I do not recommend changing this value. 
You can change `matting_mask_size` in + range from `(1024 to 4096)` to improve object edge refining quality, but it will cause extra large RAM and video memory consume. Also, you can change batch size to accelerate background removal, but it also causes extra large video memory consume, if value is too big. - - 2. Changing trimap_prob_threshold, trimap_kernel_size, trimap_erosion_iters may improve object edge - refining quality, + 2. Changing `trimap_prob_threshold`, `trimap_dilation`, `trimap_erosion_iters` may improve object edge + refining quality. """ + preprocess_pipeline = None + if object_type == "object": - self.u2net = TracerUniversalB7( + self._segnet = TracerUniversalB7( device=device, batch_size=batch_size_seg, input_image_size=seg_mask_size, fp16=fp16, ) elif object_type == "hairs-like": - self.u2net = U2NET( + self._segnet = U2NET( device=device, batch_size=batch_size_seg, input_image_size=seg_mask_size, fp16=fp16, ) + elif object_type == "auto": + # Using Tracer by default, + # but it will dynamically switch to other if needed + self._segnet = TracerUniversalB7( + device=device, + batch_size=batch_size_seg, + input_image_size=seg_mask_size, + fp16=fp16, + ) + self._scene_classifier = SceneClassifier( + device=device, fp16=fp16, batch_size=batch_size_pre + ) + preprocess_pipeline = AutoScene(scene_classifier=self._scene_classifier) + else: warnings.warn( f"Unknown object type: {object_type}. 
Using default object type: object" ) - self.u2net = TracerUniversalB7( + self._segnet = TracerUniversalB7( device=device, batch_size=batch_size_seg, input_image_size=seg_mask_size, fp16=fp16, ) - self.fba = FBAMatting( + self._cascade_psp = CascadePSP( + device=device, + batch_size=batch_size_refine, + input_tensor_size=refine_mask_size, + fp16=fp16, + ) + self._fba = FBAMatting( batch_size=batch_size_matting, device=device, input_tensor_size=matting_mask_size, fp16=fp16, ) - self.trimap_generator = TrimapGenerator( + self._trimap_generator = TrimapGenerator( prob_threshold=trimap_prob_threshold, kernel_size=trimap_dilation, erosion_iters=trimap_erosion_iters, ) super(HiInterface, self).__init__( - pre_pipe=None, - seg_pipe=self.u2net, - post_pipe=MattingMethod( - matting_module=self.fba, - trimap_generator=self.trimap_generator, + pre_pipe=preprocess_pipeline, + seg_pipe=self._segnet, + post_pipe=CasMattingMethod( + refining_module=self._cascade_psp, + matting_module=self._fba, + trimap_generator=self._trimap_generator, device=device, ), device=device, diff --git a/carvekit/api/interface.py b/carvekit/api/interface.py index 364d247..88dbe98 100644 --- a/carvekit/api/interface.py +++ b/carvekit/api/interface.py @@ -12,8 +12,8 @@ from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 from carvekit.ml.wrap.u2net import U2NET from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 -from carvekit.pipelines.preprocessing import PreprocessingStub -from carvekit.pipelines.postprocessing import MattingMethod +from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene +from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod from carvekit.utils.image_utils import load_image from carvekit.utils.mask_utils import apply_mask from carvekit.utils.pool_utils import thread_pool_processing @@ -22,19 +22,19 @@ class Interface: def __init__( self, - seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7], - pre_pipe: 
Optional[Union[PreprocessingStub]] = None, - post_pipe: Optional[Union[MattingMethod]] = None, + seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]], + pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None, + post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None, device="cpu", ): """ Initializes an object for interacting with pipelines and other components of the CarveKit framework. Args: - pre_pipe: Initialized pre-processing pipeline object - seg_pipe: Initialized segmentation network object - post_pipe: Initialized postprocessing pipeline object - device: The processing device that will be used to apply the masks to the images. + pre_pipe (Optional[Union[PreprocessingStub, AutoScene]]): Initialized pre-processing pipeline object + seg_pipe (Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]]): Initialized segmentation network object + post_pipe (Optional[Union[MattingMethod, CasMattingMethod]]): Initialized postprocessing pipeline object + device (Literal[cpu, cuda], default=cpu): The processing device that will be used to apply the masks to the images. """ self.device = device self.preprocessing_pipeline = pre_pipe @@ -53,6 +53,11 @@ def __call__( Returns: List of images without background as PIL.Image.Image instances """ + if self.segmentation_pipeline is None: + raise ValueError( + "Segmentation pipeline is not initialized." + "Override the class or pass the pipeline to the constructor." 
+ ) images = thread_pool_processing(load_image, images) if self.preprocessing_pipeline is not None: masks: List[Image.Image] = self.preprocessing_pipeline( diff --git a/carvekit/ml/arch/cascadepsp/__init__.py b/carvekit/ml/arch/cascadepsp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/carvekit/ml/arch/cascadepsp/extractors.py b/carvekit/ml/arch/cascadepsp/extractors.py new file mode 100644 index 0000000..7967796 --- /dev/null +++ b/carvekit/ml/arch/cascadepsp/extractors.py @@ -0,0 +1,127 @@ +""" +Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/hkchengrex/CascadePSP +License: MIT License +""" +import math + +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False, + ) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 
3)): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + x_1 = self.conv1(x) # /2 + x = self.bn1(x_1) + x = self.relu(x) + x = self.maxpool(x) # /2 + + x_2 = self.layer1(x) + x = self.layer2(x_2) # /2 + x = self.layer3(x) + x = self.layer4(x) + + return x, x_1, x_2 + + +def resnet50(): + model = ResNet(Bottleneck, [3, 4, 6, 3]) + return model diff --git a/carvekit/ml/arch/cascadepsp/pspnet.py b/carvekit/ml/arch/cascadepsp/pspnet.py new file mode 100644 index 0000000..350719e --- /dev/null +++ b/carvekit/ml/arch/cascadepsp/pspnet.py @@ -0,0 +1,194 @@ +""" +Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
# diff --git a/carvekit/ml/arch/cascadepsp/pspnet.py b/carvekit/ml/arch/cascadepsp/pspnet.py new file mode 100644 index 0000000..350719e
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/hkchengrex/CascadePSP
License: MIT License
"""

import torch
from torch import nn
from torch.nn import functional as F
from carvekit.ml.arch.cascadepsp.extractors import resnet50


class PSPModule(nn.Module):
    """Pyramid pooling: pool to several grid sizes, upsample, concat, fuse.

    Args:
        features: Number of input channels.
        out_features: Number of channels produced by the fusing 1x1 conv.
        sizes: Output grid sizes of the adaptive average pools.
    """

    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(features, size) for size in sizes]
        )
        # One pooled branch per size plus the untouched input.
        self.bottleneck = nn.Conv2d(
            features * (len(sizes) + 1), out_features, kernel_size=1
        )
        self.relu = nn.ReLU(inplace=True)

    def _make_stage(self, features, size):
        """Return ``AdaptiveAvgPool2d(size) -> 1x1 conv`` for one pyramid level."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        """Fuse ``feats`` with its pooled-and-upsampled pyramid levels."""
        h, w = feats.size(2), feats.size(3)
        set_priors = [
            F.interpolate(
                input=stage(feats), size=(h, w), mode="bilinear", align_corners=False
            )
            for stage in self.stages
        ]
        priors = set_priors + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)


class PSPUpsample(nn.Module):
    """Upsample by 2, merge with a skip tensor and refine with two conv stacks.

    Args:
        x_channels: Channels of the tensor being upsampled.
        in_channels: Channels after concatenating the upsampled tensor with
            the skip tensor.
        out_channels: Channels of the result.
    """

    def __init__(self, x_channels, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.conv2 = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.shortcut = nn.Conv2d(x_channels, out_channels, kernel_size=1)

    def forward(self, x, up):
        """Return the refined merge of ``x`` (upsampled x2) and skip tensor ``up``."""
        x = F.interpolate(input=x, scale_factor=2, mode="bilinear", align_corners=False)

        # The concatenation is cast back to the tensor type of ``x``.
        p = self.conv(torch.cat([x, up], 1).type(x.type()))
        sc = self.shortcut(x)

        p = p + sc

        p2 = self.conv2(p)

        return p + p2


class RefinementModule(nn.Module):
    """Three-pass cascade refinement network.

    The same backbone is run up to three times; each pass receives the image,
    the input segmentation and the coarse predictions of the earlier passes.
    """

    def __init__(self):
        super().__init__()

        self.feats = resnet50()
        self.psp = PSPModule(2048, 1024, (1, 2, 3, 6))

        self.up_1 = PSPUpsample(1024, 1024 + 256, 512)
        self.up_2 = PSPUpsample(512, 512 + 64, 256)
        self.up_3 = PSPUpsample(256, 256 + 3, 32)

        self.final_28 = nn.Sequential(
            nn.Conv2d(1024, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        )

        self.final_56 = nn.Sequential(
            nn.Conv2d(512, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        )

        self.final_11 = nn.Conv2d(32 + 3, 32, kernel_size=1)
        self.final_21 = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x, seg, inter_s8=None, inter_s4=None):
        """Run the cascade and return a dict of logits and sigmoid predictions.

        Args:
            x: Image tensor.
            seg: Segmentation tensor concatenated with ``x`` along dim 1.
            inter_s8: Optional precomputed stride-8 guidance; when given, the
                first pass is skipped and this tensor is used as is.
            inter_s4: Optional precomputed stride-4 guidance; when given, the
                second pass is skipped.

        Returns:
            Dict with ``pred_*`` (sigmoid) and ``out_*`` (raw) entries. The
            ``*_224`` keys hold the full-resolution result; intermediate keys
            are only present for the passes that actually ran.
        """
        images = {}

        # First iteration, s8 output
        if inter_s8 is None:
            p = torch.cat((x, seg, seg, seg), 1)

            f, f_1, f_2 = self.feats(p)
            p = self.psp(f)

            inter_s8 = self.final_28(p)
            r_inter_s8 = F.interpolate(
                inter_s8, scale_factor=8, mode="bilinear", align_corners=False
            )
            r_inter_tanh_s8 = torch.tanh(r_inter_s8)

            images["pred_28"] = torch.sigmoid(r_inter_s8)
            images["out_28"] = r_inter_s8
        else:
            r_inter_tanh_s8 = inter_s8

        # Second iteration, s8 output
        if inter_s4 is None:
            p = torch.cat((x, seg, r_inter_tanh_s8, r_inter_tanh_s8), 1)

            f, f_1, f_2 = self.feats(p)
            p = self.psp(f)
            inter_s8_2 = self.final_28(p)
            r_inter_s8_2 = F.interpolate(
                inter_s8_2, scale_factor=8, mode="bilinear", align_corners=False
            )
            r_inter_tanh_s8_2 = torch.tanh(r_inter_s8_2)

            p = self.up_1(p, f_2)

            inter_s4 = self.final_56(p)
            r_inter_s4 = F.interpolate(
                inter_s4, scale_factor=4, mode="bilinear", align_corners=False
            )
            r_inter_tanh_s4 = torch.tanh(r_inter_s4)

            images["pred_28_2"] = torch.sigmoid(r_inter_s8_2)
            images["out_28_2"] = r_inter_s8_2
            images["pred_56"] = torch.sigmoid(r_inter_s4)
            images["out_56"] = r_inter_s4
        else:
            r_inter_tanh_s8_2 = inter_s8
            r_inter_tanh_s4 = inter_s4

        # Third iteration, s1 output
        p = torch.cat((x, seg, r_inter_tanh_s8_2, r_inter_tanh_s4), 1)

        f, f_1, f_2 = self.feats(p)
        p = self.psp(f)
        inter_s8_3 = self.final_28(p)
        r_inter_s8_3 = F.interpolate(
            inter_s8_3, scale_factor=8, mode="bilinear", align_corners=False
        )

        p = self.up_1(p, f_2)
        inter_s4_2 = self.final_56(p)
        r_inter_s4_2 = F.interpolate(
            inter_s4_2, scale_factor=4, mode="bilinear", align_corners=False
        )
        p = self.up_2(p, f_1)
        p = self.up_3(p, x)

        # Final output
        p = F.relu(self.final_11(torch.cat([p, x], 1)), inplace=True)
        p = self.final_21(p)

        pred_224 = torch.sigmoid(p)

        images["pred_224"] = pred_224
        images["out_224"] = p
        images["pred_28_3"] = torch.sigmoid(r_inter_s8_3)
        images["pred_56_2"] = torch.sigmoid(r_inter_s4_2)
        images["out_28_3"] = r_inter_s8_3
        images["out_56_2"] = r_inter_s4_2

        return images


# diff --git a/carvekit/ml/arch/cascadepsp/utils.py b/carvekit/ml/arch/cascadepsp/utils.py new file mode 100644 index 0000000..f63a524
import torch
import torch.nn.functional as F


def resize_max_side(im, size, method):
    """Rescale ``im`` so that its longer spatial side becomes ``size``.

    Args:
        im: Tensor whose last two dimensions are height and width.
        size: Target length of the longer side.
        method: ``F.interpolate`` mode; ``align_corners=False`` is passed only
            for the modes that accept it here (bilinear, bicubic).
    """
    h, w = im.shape[-2:]
    ratio = size / max(h, w)
    if method in ["bilinear", "bicubic"]:
        return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False)
    return F.interpolate(im, scale_factor=ratio, mode=method)


def process_high_res_im(model, im, seg, L=900):
    """Refine ``seg`` with a global pass followed by overlapping local crops.

    Args:
        model: Object exposing ``safe_forward(im, seg[, inter_s8])`` that
            returns a dict with ``"pred_224"`` and ``"pred_56_2"`` entries
            (defined outside this module).
        im: Image tensor, 4-D.
        seg: Segmentation tensor, 4-D, with the same spatial size as ``im``.
        L: Side length used for the global pass and for each local crop.

    Returns:
        The refined prediction resized to the spatial size of ``seg``.
    """
    stride = L // 2

    # ---- Global step: run once on a copy whose longer side is L ----------
    _, _, h, w = seg.shape
    if max(h, w) > L:
        im_small = resize_max_side(im, L, "area")
        seg_small = resize_max_side(seg, L, "area")
    elif max(h, w) < L:
        im_small = resize_max_side(im, L, "bicubic")
        seg_small = resize_max_side(seg, L, "bilinear")
    else:
        im_small = im
        seg_small = seg

    images = model.safe_forward(im_small, seg_small)

    pred_224 = images["pred_224"]
    pred_56 = images["pred_56_2"]

    # ---- Local step: slide an L x L window over the full-size image ------
    new_size = max(h, w)
    im_small = resize_max_side(im, new_size, "area")
    seg_small = resize_max_side(seg, new_size, "area")
    _, _, h, w = seg_small.shape

    combined_224 = torch.zeros_like(seg_small)
    combined_weight = torch.zeros_like(seg_small)

    # Global predictions brought to full size and mapped to [-1, 1];
    # the fine one is binarised first.
    r_pred_224 = (
        F.interpolate(pred_224, size=(h, w), mode="bilinear", align_corners=False)
        > 0.5
    ).float() * 2 - 1
    r_pred_56 = (
        F.interpolate(pred_56, size=(h, w), mode="bilinear", align_corners=False)
        * 2
        - 1
    )

    padding = 16
    step_size = stride - padding * 2
    step_len = L
    high_thres = 0.9
    low_thres = 0.1

    used_start_idx = set()
    for x_idx in range(w // step_size + 1):
        for y_idx in range(h // step_size + 1):
            start_x = x_idx * step_size
            start_y = y_idx * step_size
            end_x = start_x + step_len
            end_y = start_y + step_len

            # Shift when required
            if end_y > h:
                end_y = h
                start_y = h - step_len
            if end_x > w:
                end_x = w
                start_x = w - step_len

            # Bound x/y range
            start_x = max(0, start_x)
            start_y = max(0, start_y)
            end_x = min(w, end_x)
            end_y = min(h, end_y)

            # The same crop might appear twice due to bounding/shifting
            start_idx = start_y * w + start_x
            if start_idx in used_start_idx:
                continue
            used_start_idx.add(start_idx)

            # Take crop
            im_part = im_small[:, :, start_y:end_y, start_x:end_x]
            seg_224_part = r_pred_224[:, :, start_y:end_y, start_x:end_x]
            seg_56_part = r_pred_56[:, :, start_y:end_y, start_x:end_x]

            # Skip when it is not an interesting crop anyway: almost all
            # foreground or almost all background. The ratio is computed
            # once instead of once per comparison.
            fg_ratio = (seg_224_part > 0).float().mean()
            if (fg_ratio > high_thres) or (fg_ratio < low_thres):
                continue
            grid_images = model.safe_forward(im_part, seg_224_part, seg_56_part)
            grid_pred_224 = grid_images["pred_224"]

            # Padding: drop a border of the crop prediction on every side
            # that is not on the image edge.
            pred_sx = pred_sy = 0
            pred_ex = step_len
            pred_ey = step_len

            if start_x != 0:
                start_x += padding
                pred_sx += padding
            if start_y != 0:
                start_y += padding
                pred_sy += padding
            if end_x != w:
                end_x -= padding
                pred_ex -= padding
            if end_y != h:
                end_y -= padding
                pred_ey -= padding

            combined_224[:, :, start_y:end_y, start_x:end_x] += grid_pred_224[
                :, :, pred_sy:pred_ey, pred_sx:pred_ex
            ]

            # Used for averaging
            combined_weight[:, :, start_y:end_y, start_x:end_x] += 1

    # Final full resolution output: average the overlapping crops and fall
    # back to the global prediction wherever no crop contributed.
    seg_norm = r_pred_224 / 2 + 0.5
    pred_224 = combined_224 / combined_weight
    pred_224 = torch.where(combined_weight == 0, seg_norm, pred_224)

    _, _, h, w = seg.shape
    return F.interpolate(pred_224, size=(h, w), mode="bilinear", align_corners=True)


def process_im_single_pass(model, im, seg, L=900):
    """
    A single pass version, aka global step only.

    The inputs are resized so that the longer side is ``L`` (unless it already
    is), ``model.safe_forward`` is run once, and ``"pred_224"`` is resized back
    to the original spatial size.
    """
    _, _, h, w = im.shape
    longest = max(h, w)

    if longest < L:
        im = resize_max_side(im, L, "bicubic")
        seg = resize_max_side(seg, L, "bilinear")
    elif longest > L:
        im = resize_max_side(im, L, "area")
        seg = resize_max_side(seg, L, "area")

    images = model.safe_forward(im, seg)

    if longest < L:
        images["pred_224"] = F.interpolate(images["pred_224"], size=(h, w), mode="area")
    elif longest > L:
        images["pred_224"] = F.interpolate(
            images["pred_224"], size=(h, w), mode="bilinear", align_corners=True
        )

    return images["pred_224"]
# diff --git a/carvekit/ml/arch/yolov4/__init__.py b/carvekit/ml/arch/yolov4/__init__.py new file mode 100644 index 0000000..e69de29
# diff --git a/carvekit/ml/arch/yolov4/models.py b/carvekit/ml/arch/yolov4/models.py new file mode 100644 index 0000000..af094f2
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4
License: Apache License 2.0
"""
import torch
from torch import nn
import torch.nn.functional as F
from carvekit.ml.arch.yolov4.yolo_layer import YoloLayer

# Flat (w, h) anchor list shared by the three detection heads.
_ANCHORS = (
    12, 16, 19, 36, 40, 28,
    36, 75, 76, 55, 72, 146,
    142, 110, 192, 243, 459, 401,
)


def get_region_boxes(boxes_and_confs):
    """Concatenate per-head ``(boxes, confs)`` pairs along dimension 1.

    Returns:
        ``[boxes, confs]`` where, per the shapes produced by the YOLO layers,
        boxes is ``[batch, num1 + num2 + num3, 1, 4]`` and confs is
        ``[batch, num1 + num2 + num3, num_classes]``.
    """
    boxes = torch.cat([item[0] for item in boxes_and_confs], dim=1)
    confs = torch.cat([item[1] for item in boxes_and_confs], dim=1)

    return [boxes, confs]


class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * (torch.tanh(torch.nn.functional.softplus(x)))
        return x


class Upsample(nn.Module):
    """Nearest-neighbour upsampling to the spatial size of a target tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x, target_size, inference=False):
        """Resize ``x`` to ``target_size[2:4]``.

        With ``inference=True`` every pixel is replicated through
        ``view``/``expand`` using the integer ratio ``target // source``;
        otherwise ``F.interpolate(mode="nearest")`` is used.
        """
        assert x.data.dim() == 4

        if inference:
            return (
                x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1)
                .expand(
                    x.size(0),
                    x.size(1),
                    x.size(2),
                    target_size[2] // x.size(2),
                    x.size(3),
                    target_size[3] // x.size(3),
                )
                .contiguous()
                .view(x.size(0), x.size(1), target_size[2], target_size[3])
            )
        return F.interpolate(
            x, size=(target_size[2], target_size[3]), mode="nearest"
        )


class Conv_Bn_Activation(nn.Module):
    """Convolution, optional batch norm and an activation chosen by name.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Square kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
        activation: One of ``"mish"``, ``"relu"``, ``"leaky"``, ``"linear"``.
        bn: Whether to add ``BatchNorm2d`` after the convolution.
        bias: Whether the convolution has a bias term.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        activation,
        bn=True,
        bias=False,
    ):
        super().__init__()
        pad = (kernel_size - 1) // 2

        # Sub-modules live in a ModuleList so that parameter names are
        # ``conv.0``, ``conv.1``, ... (checkpoints rely on these keys).
        self.conv = nn.ModuleList()
        self.conv.append(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, pad, bias=bool(bias)
            )
        )
        if bn:
            self.conv.append(nn.BatchNorm2d(out_channels))
        if activation == "mish":
            self.conv.append(Mish())
        elif activation == "relu":
            self.conv.append(nn.ReLU(inplace=True))
        elif activation == "leaky":
            self.conv.append(nn.LeakyReLU(0.1, inplace=True))
        elif activation == "linear":
            pass
        else:
            raise ValueError(
                "activation error: expected 'mish', 'relu', 'leaky' or 'linear', "
                f"got {activation!r}"
            )

    def forward(self, x):
        for layer in self.conv:
            x = layer(x)
        return x


class ResBlock(nn.Module):
    """
    Sequential residual blocks each of which consists of \
    two convolution layers.
    Args:
        ch (int): number of input and output channels.
        nblocks (int): number of residual blocks.
        shortcut (bool): if True, residual tensor addition is enabled.
    """

    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        for _ in range(nblocks):
            resblock_one = nn.ModuleList()
            resblock_one.append(Conv_Bn_Activation(ch, ch, 1, 1, "mish"))
            resblock_one.append(Conv_Bn_Activation(ch, ch, 3, 1, "mish"))
            self.module_list.append(resblock_one)

    def forward(self, x):
        for module in self.module_list:
            h = x
            for res in module:
                h = res(h)
            x = x + h if self.shortcut else h
        return x


class DownSample1(nn.Module):
    """First backbone stage: 3 -> 64 channels, spatial size halved."""

    def __init__(self):
        super().__init__()
        self.conv1 = Conv_Bn_Activation(3, 32, 3, 1, "mish")

        self.conv2 = Conv_Bn_Activation(32, 64, 3, 2, "mish")
        self.conv3 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
        # [route]
        # layers = -2
        self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, "mish")

        self.conv5 = Conv_Bn_Activation(64, 32, 1, 1, "mish")
        self.conv6 = Conv_Bn_Activation(32, 64, 3, 1, "mish")
        # [shortcut]
        # from=-3
        # activation = linear

        self.conv7 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
        # [route]
        # layers = -1, -7
        self.conv8 = Conv_Bn_Activation(128, 64, 1, 1, "mish")

    def forward(self, input):
        x1 = self.conv1(input)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        # route -2
        x4 = self.conv4(x2)
        x5 = self.conv5(x4)
        x6 = self.conv6(x5)
        # shortcut -3
        x6 = x6 + x4

        x7 = self.conv7(x6)
        # [route]
        # layers = -1, -7
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.conv8(x7)
        return x8


class _CSPDownSample(nn.Module):
    """Shared implementation of backbone stages 2-5.

    A stride-2 convolution doubles the channels, the result is split into two
    1x1 branches, one branch goes through a ``ResBlock`` stack, and both are
    concatenated and fused. Attribute names and creation order are the same
    as in the four original hand-written classes, so ``state_dict`` keys are
    unchanged.
    """

    def __init__(self, in_channels, nblocks):
        super().__init__()
        out_channels = in_channels * 2
        self.conv1 = Conv_Bn_Activation(in_channels, out_channels, 3, 2, "mish")
        self.conv2 = Conv_Bn_Activation(out_channels, in_channels, 1, 1, "mish")
        # r -2
        self.conv3 = Conv_Bn_Activation(out_channels, in_channels, 1, 1, "mish")

        self.resblock = ResBlock(ch=in_channels, nblocks=nblocks)

        # s -3
        self.conv4 = Conv_Bn_Activation(in_channels, in_channels, 1, 1, "mish")
        # r -1 -10
        self.conv5 = Conv_Bn_Activation(out_channels, out_channels, 1, 1, "mish")

    def forward(self, input):
        x1 = self.conv1(input)
        x2 = self.conv2(x1)
        x3 = self.conv3(x1)

        r = self.resblock(x3)
        x4 = self.conv4(r)

        x4 = torch.cat([x4, x2], dim=1)
        x5 = self.conv5(x4)
        return x5


class DownSample2(_CSPDownSample):
    """Backbone stage 2: 64 -> 128 channels, 2 residual blocks."""

    def __init__(self):
        super().__init__(64, 2)


class DownSample3(_CSPDownSample):
    """Backbone stage 3: 128 -> 256 channels, 8 residual blocks."""

    def __init__(self):
        super().__init__(128, 8)


class DownSample4(_CSPDownSample):
    """Backbone stage 4: 256 -> 512 channels, 8 residual blocks."""

    def __init__(self):
        super().__init__(256, 8)


class DownSample5(_CSPDownSample):
    """Backbone stage 5: 512 -> 1024 channels, 4 residual blocks."""

    def __init__(self):
        super().__init__(512, 4)


class Neck(nn.Module):
    """SPP block followed by two top-down upsample-and-merge steps.

    Args:
        inference: Forwarded to ``Upsample`` to select its replicate path.
    """

    def __init__(self, inference=False):
        super().__init__()
        self.inference = inference

        self.conv1 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
        self.conv2 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
        self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
        # SPP
        self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9 // 2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13 // 2)

        # R -1 -3 -5 -6
        # SPP
        self.conv4 = Conv_Bn_Activation(2048, 512, 1, 1, "leaky")
        self.conv5 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
        self.conv6 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
        self.conv7 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        # UP
        self.upsample1 = Upsample()
        # R 85
        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        # R -1 -3
        self.conv9 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        self.conv10 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
        self.conv11 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        self.conv12 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
        self.conv13 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        self.conv14 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
        # UP
        self.upsample2 = Upsample()
        # R 54
        self.conv15 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
        # R -1 -3
        self.conv16 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
        self.conv17 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
        self.conv18 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
        self.conv19 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
        self.conv20 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")

    def forward(self, input, downsample4, downsample3, inference=False):
        """Return ``(x20, x13, x6)``: the fine, middle and coarse neck features.

        The ``inference`` argument is accepted for signature compatibility but
        not read; the flag given to the constructor is what is used.
        """
        x1 = self.conv1(input)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        # SPP
        m1 = self.maxpool1(x3)
        m2 = self.maxpool2(x3)
        m3 = self.maxpool3(x3)
        spp = torch.cat([m3, m2, m1, x3], dim=1)
        # SPP end
        x4 = self.conv4(spp)
        x5 = self.conv5(x4)
        x6 = self.conv6(x5)
        x7 = self.conv7(x6)
        # UP
        up = self.upsample1(x7, downsample4.size(), self.inference)
        # R 85
        x8 = self.conv8(downsample4)
        # R -1 -3
        x8 = torch.cat([x8, up], dim=1)

        x9 = self.conv9(x8)
        x10 = self.conv10(x9)
        x11 = self.conv11(x10)
        x12 = self.conv12(x11)
        x13 = self.conv13(x12)
        x14 = self.conv14(x13)

        # UP
        up = self.upsample2(x14, downsample3.size(), self.inference)
        # R 54
        x15 = self.conv15(downsample3)
        # R -1 -3
        x15 = torch.cat([x15, up], dim=1)

        x16 = self.conv16(x15)
        x17 = self.conv17(x16)
        x18 = self.conv18(x17)
        x19 = self.conv19(x18)
        x20 = self.conv20(x19)
        return x20, x13, x6


class Yolov4Head(nn.Module):
    """Three detection heads at strides 8, 16 and 32.

    Args:
        output_ch: Channels of each raw head output.
        n_classes: Number of classes passed to the YOLO layers.
        inference: When true, ``forward`` decodes boxes through the YOLO
            layers; otherwise it returns the three raw head tensors.
    """

    def __init__(self, output_ch, n_classes, inference=False):
        super().__init__()
        self.inference = inference

        self.conv1 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
        self.conv2 = Conv_Bn_Activation(
            256, output_ch, 1, 1, "linear", bn=False, bias=True
        )

        self.yolo1 = YoloLayer(
            anchor_mask=[0, 1, 2],
            num_classes=n_classes,
            anchors=list(_ANCHORS),
            num_anchors=9,
            stride=8,
        )

        # R -4
        self.conv3 = Conv_Bn_Activation(128, 256, 3, 2, "leaky")

        # R -1 -16
        self.conv4 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        self.conv5 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
        self.conv6 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        self.conv7 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
        self.conv9 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
        self.conv10 = Conv_Bn_Activation(
            512, output_ch, 1, 1, "linear", bn=False, bias=True
        )

        self.yolo2 = YoloLayer(
            anchor_mask=[3, 4, 5],
            num_classes=n_classes,
            anchors=list(_ANCHORS),
            num_anchors=9,
            stride=16,
        )

        # R -4
        self.conv11 = Conv_Bn_Activation(256, 512, 3, 2, "leaky")

        # R -1 -37
        self.conv12 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
        self.conv13 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
        self.conv14 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
        self.conv15 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
        self.conv16 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
        self.conv17 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
        self.conv18 = Conv_Bn_Activation(
            1024, output_ch, 1, 1, "linear", bn=False, bias=True
        )

        self.yolo3 = YoloLayer(
            anchor_mask=[6, 7, 8],
            num_classes=n_classes,
            anchors=list(_ANCHORS),
            num_anchors=9,
            stride=32,
        )

    def forward(self, input1, input2, input3):
        x1 = self.conv1(input1)
        x2 = self.conv2(x1)

        x3 = self.conv3(input1)
        # R -1 -16
        x3 = torch.cat([x3, input2], dim=1)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        x6 = self.conv6(x5)
        x7 = self.conv7(x6)
        x8 = self.conv8(x7)
        x9 = self.conv9(x8)
        x10 = self.conv10(x9)

        # R -4
        x11 = self.conv11(x8)
        # R -1 -37
        x11 = torch.cat([x11, input3], dim=1)

        x12 = self.conv12(x11)
        x13 = self.conv13(x12)
        x14 = self.conv14(x13)
        x15 = self.conv15(x14)
        x16 = self.conv16(x15)
        x17 = self.conv17(x16)
        x18 = self.conv18(x17)

        if self.inference:
            y1 = self.yolo1(x2)
            y2 = self.yolo2(x10)
            y3 = self.yolo3(x18)

            return get_region_boxes([y1, y2, y3])

        return [x2, x10, x18]


class Yolov4(nn.Module):
    """YOLOv4 detector: five down-sampling stages, neck and head.

    Args:
        n_classes: Number of object classes.
        inference: Forwarded to the neck and the head.
    """

    def __init__(self, n_classes=80, inference=False):
        super().__init__()

        # 4 box values + 1 objectness + class scores, for 3 anchors per head.
        output_ch = (4 + 1 + n_classes) * 3

        # backbone
        self.down1 = DownSample1()
        self.down2 = DownSample2()
        self.down3 = DownSample3()
        self.down4 = DownSample4()
        self.down5 = DownSample5()
        # neck
        # NOTE: the attribute name "neek" (sic) is part of the parameter
        # names in ``state_dict``; renaming it would change checkpoint keys.
        self.neek = Neck(inference)

        # head
        self.head = Yolov4Head(output_ch, n_classes, inference)

    def forward(self, input):
        d1 = self.down1(input)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        x20, x13, x6 = self.neek(d5, d4, d3)

        output = self.head(x20, x13, x6)
        return output


# diff --git a/carvekit/ml/arch/yolov4/utils.py b/carvekit/ml/arch/yolov4/utils.py new file mode 100644 index 0000000..53cc9e9
import numpy as np


def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):
    """Greedy non-maximum suppression on NumPy arrays.

    Args:
        boxes: Array whose columns 0..3 are ``x1, y1, x2, y2``.
        confs: 1-D array of scores, one per box.
        nms_thresh: Boxes whose overlap with a kept box exceeds this value
            are discarded.
        min_mode: If true, overlap is intersection over the smaller area
            instead of intersection over union.

    Returns:
        Integer array with the indices of the kept boxes, best score first
        (empty when there are no boxes).
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1) * (y2 - y1)
    order = confs.argsort()[::-1]

    keep = []
    while order.size > 0:
        idx_self = order[0]
        idx_other = order[1:]

        keep.append(idx_self)

        xx1 = np.maximum(x1[idx_self], x1[idx_other])
        yy1 = np.maximum(y1[idx_self], y1[idx_other])
        xx2 = np.minimum(x2[idx_self], x2[idx_other])
        yy2 = np.minimum(y2[idx_self], y2[idx_other])

        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h

        if min_mode:
            over = inter / np.minimum(areas[idx_self], areas[idx_other])
        else:
            over = inter / (areas[idx_self] + areas[idx_other] - inter)

        # ``inds`` indexes into ``idx_other``, hence the +1 offset.
        inds = np.where(over <= nms_thresh)[0]
        order = order[inds + 1]

    # An explicit index dtype keeps the empty result usable as an index
    # (a bare ``np.array([])`` would be float64).
    return np.array(keep, dtype=np.intp)


def post_processing(conf_thresh, nms_thresh, output):
    """Turn raw detector output into per-image lists of detections.

    Args:
        conf_thresh: Minimum best-class confidence for a box to be considered.
        nms_thresh: Overlap threshold passed to ``nms_cpu``.
        output: Pair ``(boxes, confs)``: boxes ``[batch, num, 1, 4]`` and
            confs ``[batch, num, num_classes]``, as tensors or NumPy arrays.

    Returns:
        One list per image; every detection is
        ``[x1, y1, x2, y2, conf, conf, class_id]`` (the confidence appears
        twice, as in the original output format).
    """
    # [batch, num, 1, 4]
    box_array = output[0]
    # [batch, num, num_classes]
    confs = output[1]

    if not isinstance(box_array, np.ndarray):
        box_array = box_array.cpu().detach().numpy()
        confs = confs.cpu().detach().numpy()

    num_classes = confs.shape[2]

    # [batch, num, 4]
    box_array = box_array[:, :, 0]

    # [batch, num, num_classes] --> [batch, num]
    max_conf = np.max(confs, axis=2)
    max_id = np.argmax(confs, axis=2)

    bboxes_batch = []
    for i in range(box_array.shape[0]):
        argwhere = max_conf[i] > conf_thresh
        l_box_array = box_array[i, argwhere, :]
        l_max_conf = max_conf[i, argwhere]
        l_max_id = max_id[i, argwhere]

        bboxes = []
        # nms for each class
        for j in range(num_classes):
            cls_argwhere = l_max_id == j
            ll_box_array = l_box_array[cls_argwhere, :]
            ll_max_conf = l_max_conf[cls_argwhere]
            ll_max_id = l_max_id[cls_argwhere]

            keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
            if keep.size == 0:
                continue

            for box, conf, cls_id in zip(
                ll_box_array[keep, :], ll_max_conf[keep], ll_max_id[keep]
            ):
                bboxes.append([box[0], box[1], box[2], box[3], conf, conf, cls_id])

        bboxes_batch.append(bboxes)

    return bboxes_batch
# diff --git a/carvekit/ml/arch/yolov4/yolo_layer.py b/carvekit/ml/arch/yolov4/yolo_layer.py new file mode 100644 index 0000000..637f659
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4
License: Apache License 2.0
"""
import numpy as np
import torch
import torch.nn as nn


def _make_cell_grids(height, width, device):
    """Return ``(grid_x, grid_y)`` float32 tensors of shape ``[1, 1, H, W]``.

    ``grid_x[..., y, x] == x`` and ``grid_y[..., y, x] == y``.
    """
    grid_x = np.tile(np.linspace(0, width - 1, width), (height, 1))[None, None]
    grid_y = np.tile(np.linspace(0, height - 1, height)[:, None], (1, width))[
        None, None
    ]
    return (
        torch.tensor(grid_x, device=device, dtype=torch.float32),
        torch.tensor(grid_y, device=device, dtype=torch.float32),
    )


def yolo_forward(
    output,
    conf_thresh,
    num_classes,
    anchors,
    num_anchors,
    scale_x_y,
    only_objectness=1,
    validation=False,
):
    """Decode a raw YOLO head tensor into boxes and confidences.

    This function used to be a statement-for-statement copy of
    ``yolo_forward_dynamic`` (differing only in caching the tensor sizes in
    local variables), so it now simply delegates to it. See that function for
    the argument and return documentation.
    """
    return yolo_forward_dynamic(
        output,
        conf_thresh,
        num_classes,
        anchors,
        num_anchors,
        scale_x_y,
        only_objectness=only_objectness,
        validation=validation,
    )


def yolo_forward_dynamic(
    output,
    conf_thresh,
    num_classes,
    anchors,
    num_anchors,
    scale_x_y,
    only_objectness=1,
    validation=False,
):
    """Decode a raw YOLO head tensor into boxes and confidences.

    Args:
        output: Tensor ``[batch, C, H, W]``; the first
            ``num_anchors * (5 + num_classes)`` channels are read, laid out
            per anchor as ``x, y, w, h, objectness, class scores``.
        conf_thresh: Accepted for signature compatibility; not read here.
        num_classes: Number of class scores per anchor.
        anchors: Flat ``[w0, h0, w1, h1, ...]`` list, already expressed in
            grid cells; only the first ``num_anchors`` pairs are used.
        num_anchors: Number of anchors encoded in ``output``.
        scale_x_y: Scale applied to the sigmoid of the centre offsets.
        only_objectness: Accepted for signature compatibility; not read here.
        validation: Accepted for signature compatibility; not read here.

    Returns:
        ``(boxes, confs)`` where boxes is ``[batch, num_anchors * H * W, 1, 4]``
        holding ``x1, y1, x2, y2`` normalised by the grid size, and confs is
        ``[batch, num_anchors * H * W, num_classes]`` (class score times
        objectness).
    """
    # Output would be invalid if it does not satisfy this assert
    # assert (output.size(1) == (5 + num_classes) * num_anchors)

    batch = output.size(0)
    height = output.size(2)
    width = output.size(3)
    num_cells = num_anchors * height * width
    per_anchor = 5 + num_classes

    # Slice the second dimension (channel) of output into:
    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
    # And then into
    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
    bxy_list = []
    bwh_list = []
    det_confs_list = []
    cls_confs_list = []

    for i in range(num_anchors):
        begin = i * per_anchor
        end = begin + per_anchor

        bxy_list.append(output[:, begin : begin + 2])
        bwh_list.append(output[:, begin + 2 : begin + 4])
        det_confs_list.append(output[:, begin + 4 : begin + 5])
        cls_confs_list.append(output[:, begin + 5 : end])

    # Shape: [batch, num_anchors * 2, H, W]
    bxy = torch.cat(bxy_list, dim=1)
    # Shape: [batch, num_anchors * 2, H, W]
    bwh = torch.cat(bwh_list, dim=1)

    # Shape: [batch, num_anchors, H, W]
    det_confs = torch.cat(det_confs_list, dim=1)
    # Shape: [batch, num_anchors * H * W]
    det_confs = det_confs.view(batch, num_cells)

    # Shape: [batch, num_anchors * num_classes, H, W]
    cls_confs = torch.cat(cls_confs_list, dim=1)
    # Shape: [batch, num_anchors, num_classes, H * W]
    cls_confs = cls_confs.view(batch, num_anchors, num_classes, height * width)
    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(batch, num_cells, num_classes)

    # Apply sigmoid() and exp() to slices
    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
    bwh = torch.exp(bwh)
    det_confs = torch.sigmoid(det_confs)
    cls_confs = torch.sigmoid(cls_confs)

    # Cell offsets C-x, C-y. Built once and placed on the device of
    # ``output``; previously they were converted again for every anchor.
    grid_x, grid_y = _make_cell_grids(height, width, output.device)

    bx_list = []
    by_list = []
    bw_list = []
    bh_list = []

    # Apply C-x, C-y, P-w, P-h
    for i in range(num_anchors):
        ii = i * 2
        # Shape of every entry: [batch, 1, H, W]
        bx_list.append(bxy[:, ii : ii + 1] + grid_x)
        by_list.append(bxy[:, ii + 1 : ii + 2] + grid_y)
        # Anchors stay Python scalars so the result keeps the dtype of bwh.
        bw_list.append(bwh[:, ii : ii + 1] * anchors[ii])
        bh_list.append(bwh[:, ii + 1 : ii + 2] * anchors[ii + 1])

    ########################################
    #   Figure out bboxes from slices     #
    ########################################

    # Shape: [batch, num_anchors, H, W]
    bx = torch.cat(bx_list, dim=1)
    by = torch.cat(by_list, dim=1)
    bw = torch.cat(bw_list, dim=1)
    bh = torch.cat(bh_list, dim=1)

    # Shape: [batch, 2 * num_anchors, H, W]
    bx_bw = torch.cat((bx, bw), dim=1)
    by_bh = torch.cat((by, bh), dim=1)

    # normalize coordinates to [0, 1]
    bx_bw /= width
    by_bh /= height

    # Shape: [batch, num_anchors * H * W, 1]
    bx = bx_bw[:, :num_anchors].view(batch, num_cells, 1)
    by = by_bh[:, :num_anchors].view(batch, num_cells, 1)
    bw = bx_bw[:, num_anchors:].view(batch, num_cells, 1)
    bh = by_bh[:, num_anchors:].view(batch, num_cells, 1)

    # Centre/size -> corner coordinates.
    bx1 = bx - bw * 0.5
    by1 = by - bh * 0.5
    bx2 = bx1 + bw
    by2 = by1 + bh

    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, num_cells, 1, 4)

    # boxes:     [batch, num_anchors * H * W, 1, 4]
    # cls_confs: [batch, num_anchors * H * W, num_classes]
    # det_confs: [batch, num_anchors * H * W]
    det_confs = det_confs.view(batch, num_cells, 1)
    confs = cls_confs * det_confs

    # boxes: [batch, num_anchors * H * W, 1, 4]
    # confs: [batch, num_anchors * H * W, num_classes]
    return boxes, confs


class YoloLayer(nn.Module):
    """Yolo layer
    model_out: while inference,is post-processing inside or outside the model
        true:outside

    Args:
        anchor_mask: Indices of the anchors this layer uses (default: none).
        num_classes: Number of classes.
        anchors: Flat list of anchor values for all layers (default: empty).
        num_anchors: Total number of anchors described by ``anchors``.
        stride: Down-sampling factor of this layer; anchors are divided by it.
        model_out: Stored on the instance; not read by this class.
    """

    def __init__(
        self,
        anchor_mask=None,
        num_classes=0,
        anchors=None,
        num_anchors=1,
        stride=32,
        model_out=False,
    ):
        super().__init__()
        # ``None`` replaces the former mutable ``[]`` defaults, which were
        # shared between every instance created without these arguments.
        self.anchor_mask = [] if anchor_mask is None else anchor_mask
        self.num_classes = num_classes
        self.anchors = [] if anchors is None else anchors
        self.num_anchors = num_anchors
        self.anchor_step = len(self.anchors) // num_anchors
        self.coord_scale = 1
        self.noobject_scale = 1
        self.object_scale = 5
        self.class_scale = 1
        self.thresh = 0.6
        self.stride = stride
        self.seen = 0
        self.scale_x_y = 1

        self.model_out = model_out

    def forward(self, output, target=None):
        """Return ``output`` untouched in training mode, else decoded boxes.

        In eval mode the anchors selected by ``anchor_mask`` are divided by
        ``stride`` and ``yolo_forward_dynamic`` is applied; the result is its
        ``(boxes, confs)`` pair. ``target`` is not read.
        """
        if self.training:
            return output
        masked_anchors = []
        for m in self.anchor_mask:
            masked_anchors += self.anchors[
                m * self.anchor_step : (m + 1) * self.anchor_step
            ]
        masked_anchors = [anchor / self.stride for anchor in masked_anchors]

        return yolo_forward_dynamic(
            output,
            self.thresh,
            self.num_classes,
            masked_anchors,
            len(self.anchor_mask),
            scale_x_y=self.scale_x_y,
        )


# Diff context for the next file in the patch (kept verbatim, continues past this span):
# diff --git a/carvekit/ml/files/models_loc.py b/carvekit/ml/files/models_loc.py index 45f9a56..cf43ab5 100644 --- a/carvekit/ml/files/models_loc.py +++
b/carvekit/ml/files/models_loc.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. + License: Apache License 2.0 """ import pathlib @@ -12,7 +14,7 @@ def u2net_full_pretrained() -> pathlib.Path: """Returns u2net pretrained model location Returns: - pathlib.Path to model location + pathlib.Path: model location """ return downloader("u2net.pth") @@ -21,7 +23,7 @@ def basnet_pretrained() -> pathlib.Path: """Returns basnet pretrained model location Returns: - pathlib.Path to model location + pathlib.Path: model location """ return downloader("basnet.pth") @@ -30,7 +32,7 @@ def deeplab_pretrained() -> pathlib.Path: """Returns basnet pretrained model location Returns: - pathlib.Path to model location + pathlib.Path: model location """ return downloader("deeplab.pth") @@ -39,7 +41,7 @@ def fba_pretrained() -> pathlib.Path: """Returns basnet pretrained model location Returns: - pathlib.Path to model location + pathlib.Path: model location """ return downloader("fba_matting.pth") @@ -48,18 +50,54 @@ def tracer_b7_pretrained() -> pathlib.Path: """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location Returns: - pathlib.Path to model location + pathlib.Path: model location """ return downloader("tracer_b7.pth") -def tracer_hair_pretrained() -> pathlib.Path: - """Returns TRACER with EfficientNet v1 b7 encoder model for hair segmentation location +def scene_classifier_pretrained() -> pathlib.Path: + """Returns scene classifier pretrained model location + This model is used to classify scenes into 3 categories: hard, soft, digital + + hard - scenes with hard edges, such as objects, buildings, etc. + soft - scenes with soft edges, such as portraits, hairs, animal, etc. + digital - digital scenes, such as screenshots, graphics, etc. 
+ + more info: https://huggingface.co/Carve/scene_classifier + + Returns: + pathlib.Path: model location + """ + return downloader("scene_classifier.pth") + + +def yolov4_coco_pretrained() -> pathlib.Path: + """Returns yolov4 classifier pretrained model location + This model is used to classify objects in images. + + Training dataset: COCO 2017 + Training classes: 80 + + It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch) + We have only added coco classnames to the model. + + Returns: + pathlib.Path to model location + """ + return downloader("yolov4_coco_with_classes.pth") + + +def cascadepsp_pretrained() -> pathlib.Path: + """Returns cascade psp pretrained model location + This model is used to refine segmentation masks. + + Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000 + more info: https://huggingface.co/Carve/cascadepsp Returns: pathlib.Path to model location """ - return downloader("tracer_hair.pth") + return downloader("cascadepsp.pth") def download_all(): @@ -68,3 +106,6 @@ def download_all(): deeplab_pretrained() basnet_pretrained() tracer_b7_pretrained() + scene_classifier_pretrained() + yolov4_coco_pretrained() + cascadepsp_pretrained() diff --git a/carvekit/ml/wrap/basnet.py b/carvekit/ml/wrap/basnet.py index 9912e81..836de7f 100644 --- a/carvekit/ml/wrap/basnet.py +++ b/carvekit/ml/wrap/basnet.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+ License: Apache License 2.0 """ import pathlib @@ -34,12 +36,11 @@ def __init__( Initialize the BASNET model Args: - device: processing device - input_image_size: input image size - batch_size: the number of images that the neural network processes in one run - load_pretrained: loading pretrained model - fp16: use fp16 precision // not supported at this moment - + device (Literal[cpu, cuda], default=cpu): processing device + input_image_size (Union[List[int], int], default=320): input image size + batch_size (int, default=10): the number of images that the neural network processes in one run + load_pretrained (bool, default=True): loading pretrained model + fp16 (bool, default=True): use fp16 precision **not supported at this moment** """ super(BASNET, self).__init__(n_channels=3, n_classes=1) self.device = device @@ -60,10 +61,10 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: Transform input image to suitable data format for neural network Args: - data: input image + data (PIL.Image.Image): input image Returns: - input for neural network + torch.Tensor: input for neural network """ resized = data.resize(self.input_image_size) @@ -81,18 +82,18 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: @staticmethod def data_postprocessing( - data: torch.tensor, original_image: PIL.Image.Image + data: torch.Tensor, original_image: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. 
Args: - data: output data from neural network - original_image: input image which was used for predicted data + data (torch.Tensor): output data from neural network + original_image (PIL.Image.Image): input image which was used for predicted data Returns: - Segmentation mask as PIL Image instance + PIL.Image.Image: Segmentation mask as `PIL Image` instance """ data = data.unsqueeze(0) @@ -109,22 +110,22 @@ def __call__( self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] ) -> List[PIL.Image.Image]: """ - Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances + Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances Args: - images: input images + images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images Returns: - segmentation masks as for input images, as PIL.Image.Image instances + List[PIL.Image.Image]: segmentation masks as for input images, as `PIL.Image.Image` instances """ collect_masks = [] for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing( + converted_images = thread_pool_processing( lambda x: convert_image(load_image(x)), image_batch ) batches = torch.vstack( - thread_pool_processing(self.data_preprocessing, images) + thread_pool_processing(self.data_preprocessing, converted_images) ) with torch.no_grad(): batches = batches.to(self.device) @@ -134,8 +135,8 @@ def __call__( masks_cpu = masks.cpu() del d2, d3, d4, d5, d6, d7, d8, batches, masks masks = thread_pool_processing( - lambda x: self.data_postprocessing(masks_cpu[x], images[x]), - range(len(images)), + lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]), + range(len(converted_images)), ) collect_masks += masks return collect_masks diff --git a/carvekit/ml/wrap/cascadepsp.py b/carvekit/ml/wrap/cascadepsp.py new file mode 100644 index 0000000..1d0fc9a --- /dev/null +++ b/carvekit/ml/wrap/cascadepsp.py @@ -0,0 +1,310 @@ 
+""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0 +""" +import pathlib +import warnings + +import PIL +import numpy as np +import torch +from PIL import Image +from torchvision import transforms +from typing import Union, List + +from carvekit.ml.arch.cascadepsp.pspnet import RefinementModule +from carvekit.ml.arch.cascadepsp.utils import ( + process_im_single_pass, + process_high_res_im, +) +from carvekit.ml.files.models_loc import cascadepsp_pretrained +from carvekit.utils.image_utils import convert_image, load_image +from carvekit.utils.models_utils import get_precision_autocast, cast_network +from carvekit.utils.pool_utils import batch_generator, thread_pool_processing + +__all__ = ["CascadePSP"] + + +class CascadePSP(RefinementModule): + """ + CascadePSP to refine the mask from segmentation network + """ + + def __init__( + self, + device="cpu", + input_tensor_size: int = 900, + batch_size: int = 1, + load_pretrained: bool = True, + fp16: bool = False, + mask_binary_threshold=127, + global_step_only=False, + processing_accelerate_image_size=2048, + ): + """ + Initialize the CascadePSP model + + Args: + device: processing device + input_tensor_size: input image size + batch_size: the number of images that the neural network processes in one run + load_pretrained: loading pretrained model + fp16: use half precision + global_step_only: if True, only global step will be used for prediction. See paper for details. + mask_binary_threshold: threshold for binary mask, default 127, set to 0 for no threshold + processing_accelerate_image_size: thumbnail size for image processing acceleration. 
Set to 0 to disable + + """ + super().__init__() + self.fp16 = fp16 + self.device = device + self.batch_size = batch_size + self.mask_binary_threshold = mask_binary_threshold + self.global_step_only = global_step_only + self.processing_accelerate_image_size = processing_accelerate_image_size + self.input_tensor_size = input_tensor_size + + self.to(device) + if batch_size > 1: + warnings.warn( + "Batch size > 1 is experimental feature for CascadePSP." + " Please, don't use it if you have GPU with small memory!" + ) + if load_pretrained: + self.load_state_dict( + torch.load(cascadepsp_pretrained(), map_location=self.device) + ) + self.eval() + + self._image_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + self._seg_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]), + ] + ) + + def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor: + """ + Transform input image to suitable data format for neural network + + Args: + data: input image + + Returns: + input for neural network + + """ + preprocessed_data = data.copy() + if self.batch_size == 1 and self.processing_accelerate_image_size > 0: + # Okay, we have only one image, so + # we can use image processing acceleration to accelerate high resolution image processing + preprocessed_data.thumbnail( + ( + self.processing_accelerate_image_size, + self.processing_accelerate_image_size, + ) + ) + elif self.batch_size == 1: + pass # No need to do anything + elif self.batch_size > 1 and self.global_step_only is True: + # If we have more than one image and we use only global step, + # there is no reason to use image processing acceleration, + # because we will use only global step for prediction and anyway it will be resized to input_tensor_size + preprocessed_data = preprocessed_data.resize( + (self.input_tensor_size, 
self.input_tensor_size) + ) + elif ( + self.batch_size > 1 + and self.global_step_only is False + and self.processing_accelerate_image_size > 0 + ): + # If we have more than one image and we use local step, + # we can use image processing acceleration for accelerate high resolution image processing + # but we need to resize image to processing_accelerate_image_size to stack it with other images + preprocessed_data = preprocessed_data.resize( + ( + self.processing_accelerate_image_size, + self.processing_accelerate_image_size, + ) + ) + elif ( + self.batch_size > 1 + and self.global_step_only is False + and not (self.processing_accelerate_image_size > 0) + ): + raise ValueError( + "If you use local step with batch_size > 2, " + "you need to set processing_accelerate_image_size > 0," + "since we cannot stack images with different sizes to one batch" + ) + else: # some extra cases + preprocessed_data = preprocessed_data.resize( + ( + self.processing_accelerate_image_size, + self.processing_accelerate_image_size, + ) + ) + + if data.mode == "RGB": + preprocessed_data = self._image_transform( + np.array(preprocessed_data) + ).unsqueeze(0) + elif data.mode == "L": + preprocessed_data = np.array(preprocessed_data) + if 0 < self.mask_binary_threshold <= 255: + preprocessed_data = ( + preprocessed_data > self.mask_binary_threshold + ).astype(np.uint8) * 255 + elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0: + warnings.warn( + "mask_binary_threshold should be in range [0, 255], " + "but got {}. Disabling mask_binary_threshold!".format( + self.mask_binary_threshold + ) + ) + + preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze( + 0 + ) # [H,W,1] + + return preprocessed_data + + @staticmethod + def data_postprocessing( + data: torch.Tensor, mask: PIL.Image.Image + ) -> PIL.Image.Image: + """ + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
+ + Args: + data: output data from neural network + mask: input mask + + Returns: + Segmentation mask as PIL Image instance + + """ + refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8") + return Image.fromarray(refined_mask).convert("L").resize(mask.size) + + def safe_forward(self, im, seg, inter_s8=None, inter_s4=None): + """ + Slightly pads the input image such that its height and width are multiples of 8 + """ + b, _, ph, pw = seg.shape + if (ph % 8 != 0) or (pw % 8 != 0): + newH = (ph // 8 + 1) * 8 + newW = (pw // 8 + 1) * 8 + p_im = torch.zeros(b, 3, newH, newW, device=im.device) + p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1 + + p_im[:, :, 0:ph, 0:pw] = im + p_seg[:, :, 0:ph, 0:pw] = seg + im = p_im + seg = p_seg + + if inter_s8 is not None: + p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1 + p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8 + inter_s8 = p_inter_s8 + if inter_s4 is not None: + p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1 + p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4 + inter_s4 = p_inter_s4 + + images = super().__call__(im, seg, inter_s8, inter_s4) + return_im = {} + + for key in ["pred_224", "pred_28_3", "pred_56_2"]: + return_im[key] = images[key][:, :, 0:ph, 0:pw] + del images + + return return_im + + def __call__( + self, + images: List[Union[str, pathlib.Path, PIL.Image.Image]], + masks: List[Union[str, pathlib.Path, PIL.Image.Image]], + ) -> List[PIL.Image.Image]: + """ + Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances + + Args: + images: input images + masks: Segmentation masks to refine + + Returns: + refined segmentation masks for the input images, as PIL.Image.Image instances + + """ + + if len(images) != len(masks): + raise ValueError( + "Len of specified arrays of images and trimaps should be equal!" 
+ ) + + collect_masks = [] + autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) + with autocast: + cast_network(self, dtype) + for idx_batch in batch_generator(range(len(images)), self.batch_size): + inpt_images = thread_pool_processing( + lambda x: convert_image(load_image(images[x])), idx_batch + ) + + inpt_masks = thread_pool_processing( + lambda x: convert_image(load_image(masks[x]), mode="L"), idx_batch + ) + + inpt_img_batches = thread_pool_processing( + self.data_preprocessing, inpt_images + ) + inpt_masks_batches = thread_pool_processing( + self.data_preprocessing, inpt_masks + ) + if self.batch_size > 1: # We need to stack images, if batch_size > 1 + inpt_img_batches = torch.vstack(inpt_img_batches) + inpt_masks_batches = torch.vstack(inpt_masks_batches) + else: + inpt_img_batches = inpt_img_batches[ + 0 + ] # Get only one image from list + inpt_masks_batches = inpt_masks_batches[0] + + with torch.no_grad(): + inpt_img_batches = inpt_img_batches.to(self.device) + inpt_masks_batches = inpt_masks_batches.to(self.device) + if self.global_step_only: + refined_batches = process_im_single_pass( + self, + inpt_img_batches, + inpt_masks_batches, + self.input_tensor_size, + ) + + else: + refined_batches = process_high_res_im( + self, + inpt_img_batches, + inpt_masks_batches, + self.input_tensor_size, + ) + + refined_masks = refined_batches.cpu() + del (inpt_img_batches, inpt_masks_batches, refined_batches) + collect_masks += thread_pool_processing( + lambda x: self.data_postprocessing(refined_masks[x], inpt_masks[x]), + range(len(inpt_masks)), + ) + return collect_masks diff --git a/carvekit/ml/wrap/deeplab_v3.py b/carvekit/ml/wrap/deeplab_v3.py index 4b19542..d570795 100644 --- a/carvekit/ml/wrap/deeplab_v3.py +++ b/carvekit/ml/wrap/deeplab_v3.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+ License: Apache License 2.0 """ import pathlib @@ -29,14 +31,14 @@ def __init__( fp16: bool = False, ): """ - Initialize the DeepLabV3 model + Initialize the `DeepLabV3` model Args: - device: processing device - input_image_size: input image size - batch_size: the number of images that the neural network processes in one run - load_pretrained: loading pretrained model - fp16: use half precision + device (Literal[cpu, cuda], default=cpu): processing device + input_image_size (): input image size + batch_size (int, default=10): the number of images that the neural network processes in one run + load_pretrained (bool, default=True): loading pretrained model + fp16 (bool, default=False): use half precision """ self.device = device @@ -69,9 +71,7 @@ def to(self, device: str): Moves neural network to specified processing device Args: - device (:class:`torch.device`): the desired device. - Returns: - None + device (Literal[cpu, cuda]): the desired device. """ self.network.to(device) @@ -81,10 +81,10 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: Transform input image to suitable data format for neural network Args: - data: input image + data (PIL.Image.Image): input image Returns: - input for neural network + torch.Tensor: input for neural network """ copy = data.copy() @@ -93,18 +93,18 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: @staticmethod def data_postprocessing( - data: torch.tensor, original_image: PIL.Image.Image + data: torch.Tensor, original_image: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. 
Args: - data: output data from neural network - original_image: input image which was used for predicted data + data (torch.Tensor): output data from neural network + original_image (PIL.Image.Image): input image which was used for predicted data Returns: - Segmentation mask as PIL Image instance + PIL.Image.Image: Segmentation mask as `PIL Image` instance """ return ( @@ -115,13 +115,13 @@ def __call__( self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] ) -> List[PIL.Image.Image]: """ - Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances + Passes input images though neural network and returns segmentation masks as `PIL.Image.Image` instances Args: - images: input images + images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images Returns: - segmentation masks as for input images, as PIL.Image.Image instances + List[PIL.Image.Image]: segmentation masks as for input images, as `PIL.Image.Image` instances """ collect_masks = [] @@ -129,10 +129,12 @@ def __call__( with autocast: cast_network(self.network, dtype) for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing( + converted_images = thread_pool_processing( lambda x: convert_image(load_image(x)), image_batch ) - batches = thread_pool_processing(self.data_preprocessing, images) + batches = thread_pool_processing( + self.data_preprocessing, converted_images + ) with torch.no_grad(): masks = [ self.network(i.to(self.device).unsqueeze(0))["out"][0] @@ -143,8 +145,8 @@ def __call__( ] del batches masks = thread_pool_processing( - lambda x: self.data_postprocessing(masks[x], images[x]), - range(len(images)), + lambda x: self.data_postprocessing(masks[x], converted_images[x]), + range(len(converted_images)), ) collect_masks += masks return collect_masks diff --git a/carvekit/ml/wrap/fba_matting.py b/carvekit/ml/wrap/fba_matting.py index c285df0..19a2659 100644 --- a/carvekit/ml/wrap/fba_matting.py +++ 
b/carvekit/ml/wrap/fba_matting.py @@ -43,12 +43,14 @@ def __init__( Initialize the FBAMatting model Args: - device: processing device - input_tensor_size: input image size - batch_size: the number of images that the neural network processes in one run - encoder: neural network encoder head - load_pretrained: loading pretrained model - fp16: use half precision + device (Literal[cpu, cuda], default=cpu): processing device + input_tensor_size (Union[List[int], int], default=2048): input image size + batch_size (int, default=2): the number of images that the neural network processes in one run + encoder (str, default=resnet50_GN_WS): neural network encoder head + .. TODO:: + Add more encoders to documentation as Literal typehint. + load_pretrained (bool, default=True): loading pretrained model + fp16 (bool, default=False): use half precision """ super(FBAMatting, self).__init__(encoder=encoder) @@ -71,10 +73,10 @@ def data_preprocessing( Transform input image to suitable data format for neural network Args: - data: input image + data (Union[PIL.Image.Image, np.ndarray]): input image Returns: - input for neural network + Tuple[torch.FloatTensor, torch.FloatTensor]: input for neural network """ resized = data.copy() @@ -114,18 +116,18 @@ def data_preprocessing( @staticmethod def data_postprocessing( - data: torch.tensor, trimap: PIL.Image.Image + data: torch.Tensor, trimap: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. 
Args: - data: output data from neural network - trimap: Map with the area we need to refine + data (torch.Tensor): output data from neural network + trimap (PIL.Image.Image): Map with the area we need to refine Returns: - Segmentation mask as PIL Image instance + PIL.Image.Image: Segmentation mask """ if trimap.mode != "L": @@ -149,11 +151,11 @@ def __call__( Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances Args: - images: input images - trimaps: Maps with the areas we need to refine + images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images + trimaps (List[Union[str, pathlib.Path, PIL.Image.Image]]): Maps with the areas we need to refine Returns: - segmentation masks as for input images, as PIL.Image.Image instances + List[PIL.Image.Image]: segmentation masks as for input images """ diff --git a/carvekit/ml/wrap/scene_classifier.py b/carvekit/ml/wrap/scene_classifier.py new file mode 100644 index 0000000..75c0b39 --- /dev/null +++ b/carvekit/ml/wrap/scene_classifier.py @@ -0,0 +1,150 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0 +""" +import pathlib + +import PIL.Image +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from typing import List, Union, Tuple +from torch.autograd import Variable + +from carvekit.ml.files.models_loc import scene_classifier_pretrained +from carvekit.utils.image_utils import load_image, convert_image +from carvekit.utils.models_utils import get_precision_autocast, cast_network +from carvekit.utils.pool_utils import thread_pool_processing, batch_generator + +__all__ = ["SceneClassifier"] + + +class SceneClassifier: + """ + SceneClassifier model interface + + Description: + Performs a primary analysis of the image in order to select the necessary method for removing the background. 
+ The choice is made by classifying the scene type. + + The output can be the following types: + - hard + - soft + - digital + + """ + + def __init__( + self, + topk: int = 1, + device="cpu", + batch_size: int = 4, + fp16: bool = False, + model_path: Union[str, pathlib.Path] = None, + ): + """ + Initialize the Scene Classifier. + + Args: + topk: number of top classes to return + device: processing device + batch_size: the number of images that the neural network processes in one run + fp16: use fp16 precision + + """ + if model_path is None: + model_path = scene_classifier_pretrained() + self.topk = topk + self.fp16 = fp16 + self.device = device + self.batch_size = batch_size + + self.transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + state_dict = torch.load(model_path, map_location=device) + self.model = state_dict["model"] + self.class_to_idx = state_dict["class_to_idx"] + self.idx_to_class = {v: k for k, v in self.class_to_idx.items()} + self.model.to(device) + self.model.eval() + + def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: + """ + Transform input image to suitable data format for neural network + + Args: + data: input image + + Returns: + input for neural network + + """ + + return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor) + + def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]: + """ + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
+ + Args: + data: output data from neural network + + Returns: + Top-k classes of scene type, probability of these classes + + """ + ps = F.softmax(data.float(), dim=0) + topk = ps.cpu().topk(self.topk) + + probs, classes = (e.data.numpy().squeeze().tolist() for e in topk) + if isinstance(classes, int): + classes = [classes] + probs = [probs] + return list(map(lambda x: self.idx_to_class[x], classes)), probs + + def __call__( + self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] + ) -> Tuple[List[str], List[float]]: + """ + Passes input images through neural network and returns class predictions. + + Args: + images: input images + + Returns: + Top-k classes of scene type, probability of these classes for every passed image + + """ + collect_masks = [] + autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) + with autocast: + cast_network(self.model, dtype) + for image_batch in batch_generator(images, self.batch_size): + converted_images = thread_pool_processing( + lambda x: convert_image(load_image(x)), image_batch + ) + batches = torch.vstack( + thread_pool_processing(self.data_preprocessing, converted_images) + ) + with torch.no_grad(): + batches = Variable(batches).to(self.device) + masks = self.model.forward(batches) + masks_cpu = masks.cpu() + del batches, masks + masks = thread_pool_processing( + lambda x: self.data_postprocessing(masks_cpu[x]), + range(len(converted_images)), + ) + collect_masks += masks + + return collect_masks diff --git a/carvekit/ml/wrap/tracer_b7.py b/carvekit/ml/wrap/tracer_b7.py index 20a8e45..214b095 100644 --- a/carvekit/ml/wrap/tracer_b7.py +++ b/carvekit/ml/wrap/tracer_b7.py @@ -4,19 +4,19 @@ License: Apache License 2.0 """ import pathlib -import warnings from typing import List, Union + import PIL.Image import numpy as np import torch import torchvision.transforms as transforms from PIL import Image -from carvekit.ml.arch.tracerb7.tracer import TracerDecoder from carvekit.ml.arch.tracerb7.efficientnet 
import EfficientEncoderB7 -from carvekit.ml.files.models_loc import tracer_b7_pretrained, tracer_hair_pretrained -from carvekit.utils.models_utils import get_precision_autocast, cast_network +from carvekit.ml.arch.tracerb7.tracer import TracerDecoder +from carvekit.ml.files.models_loc import tracer_b7_pretrained from carvekit.utils.image_utils import load_image, convert_image +from carvekit.utils.models_utils import get_precision_autocast, cast_network from carvekit.utils.pool_utils import thread_pool_processing, batch_generator __all__ = ["TracerUniversalB7"] @@ -35,16 +35,16 @@ def __init__( model_path: Union[str, pathlib.Path] = None, ): """ - Initialize the U2NET model + Initialize the TRACER model Args: - layers_cfg: neural network layers configuration - device: processing device - input_image_size: input image size - batch_size: the number of images that the neural network processes in one run - load_pretrained: loading pretrained model - fp16: use fp16 precision - + device (Literal[cpu, cuda], default=cpu): processing device + input_image_size (Union[List[int], int], default=640): input image size + batch_size(int, default=4): the number of images that the neural network processes in one run + load_pretrained(bool, default=True): loading pretrained model + fp16 (bool, default=False): use fp16 precision + model_path (Union[str, pathlib.Path], default=None): path to the model + .. 
note:: REDO """ if model_path is None: model_path = tracer_b7_pretrained() @@ -82,10 +82,10 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: Transform input image to suitable data format for neural network Args: - data: input image + data (PIL.Image.Image): input image Returns: - input for neural network + torch.FloatTensor: input for neural network """ @@ -93,18 +93,18 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: @staticmethod def data_postprocessing( - data: torch.tensor, original_image: PIL.Image.Image + data: torch.Tensor, original_image: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. Args: - data: output data from neural network - original_image: input image which was used for predicted data + data (torch.Tensor): output data from neural network + original_image (PIL.Image.Image): input image which was used for predicted data Returns: - Segmentation mask as PIL Image instance + PIL.Image.Image: Segmentation mask """ output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype( @@ -122,10 +122,10 @@ def __call__( Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances Args: - images: input images + images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images Returns: - segmentation masks as for input images, as PIL.Image.Image instances + List[PIL.Image.Image]: segmentation masks as for input images """ collect_masks = [] @@ -133,11 +133,11 @@ def __call__( with autocast: cast_network(self, dtype) for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing( + converted_images = thread_pool_processing( lambda x: convert_image(load_image(x)), image_batch ) batches = torch.vstack( - thread_pool_processing(self.data_preprocessing, images) + thread_pool_processing(self.data_preprocessing, 
converted_images) ) with torch.no_grad(): batches = batches.to(self.device) @@ -145,34 +145,11 @@ def __call__( masks_cpu = masks.cpu() del batches, masks masks = thread_pool_processing( - lambda x: self.data_postprocessing(masks_cpu[x], images[x]), - range(len(images)), + lambda x: self.data_postprocessing( + masks_cpu[x], converted_images[x] + ), + range(len(converted_images)), ) collect_masks += masks return collect_masks - - -class TracerHair(TracerUniversalB7): - """TRACER HAIR model interface""" - - def __init__( - self, - device="cpu", - input_image_size: Union[List[int], int] = 640, - batch_size: int = 4, - load_pretrained: bool = True, - fp16: bool = False, - model_path: Union[str, pathlib.Path] = None, - ): - if model_path is None: - model_path = tracer_hair_pretrained() - warnings.warn("TracerHair has not public model yet. Don't use it!", UserWarning) - super(TracerHair, self).__init__( - device=device, - input_image_size=input_image_size, - batch_size=batch_size, - load_pretrained=load_pretrained, - fp16=fp16, - model_path=model_path, - ) diff --git a/carvekit/ml/wrap/u2net.py b/carvekit/ml/wrap/u2net.py index 7d126df..4a0eb57 100644 --- a/carvekit/ml/wrap/u2net.py +++ b/carvekit/ml/wrap/u2net.py @@ -4,6 +4,8 @@ License: Apache License 2.0 """ import pathlib +import warnings + from typing import List, Union import PIL.Image import numpy as np @@ -43,6 +45,8 @@ def __init__( """ super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1) + if fp16: + warnings.warn("FP16 is not supported at this moment for U2NET model") self.device = device self.batch_size = batch_size if isinstance(input_image_size, list): @@ -54,6 +58,7 @@ def __init__( self.load_state_dict( torch.load(u2net_full_pretrained(), map_location=self.device) ) + self.eval() def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: @@ -61,10 +66,10 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: Transform input image to suitable data format for 
neural network Args: - data: input image + data (PIL.Image.Image): input image Returns: - input for neural network + torch.FloatTensor: input for neural network """ resized = data.resize(self.input_image_size, resample=3) @@ -82,18 +87,18 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: @staticmethod def data_postprocessing( - data: torch.tensor, original_image: PIL.Image.Image + data: torch.Tensor, original_image: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. Args: - data: output data from neural network - original_image: input image which was used for predicted data + data (torch.Tensor): output data from neural network + original_image (PIL.Image.Image): input image which was used for predicted data Returns: - Segmentation mask as PIL Image instance + PIL.Image.Image: Segmentation mask as `PIL Image` instance """ data = data.unsqueeze(0) @@ -121,11 +126,11 @@ def __call__( """ collect_masks = [] for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing( + converted_images = thread_pool_processing( lambda x: convert_image(load_image(x)), image_batch ) batches = torch.vstack( - thread_pool_processing(self.data_preprocessing, images) + thread_pool_processing(self.data_preprocessing, converted_images) ) with torch.no_grad(): batches = batches.to(self.device) @@ -133,8 +138,8 @@ def __call__( masks_cpu = masks.cpu() del d2, d3, d4, d5, d6, d7, batches, masks masks = thread_pool_processing( - lambda x: self.data_postprocessing(masks_cpu[x], images[x]), - range(len(images)), + lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]), + range(len(converted_images)), ) collect_masks += masks return collect_masks diff --git a/carvekit/ml/wrap/yolov4.py b/carvekit/ml/wrap/yolov4.py new file mode 100644 index 0000000..cf59233 --- /dev/null +++ b/carvekit/ml/wrap/yolov4.py @@ -0,0 
+1,296 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0 +""" + +import pathlib + +import PIL.Image +import PIL.Image +import numpy as np +import pydantic +import torch +from torch.autograd import Variable +from typing import List, Union + +from carvekit.ml.arch.yolov4.models import Yolov4 +from carvekit.ml.arch.yolov4.utils import post_processing +from carvekit.ml.files.models_loc import yolov4_coco_pretrained +from carvekit.utils.image_utils import load_image, convert_image +from carvekit.utils.models_utils import get_precision_autocast, cast_network +from carvekit.utils.pool_utils import thread_pool_processing, batch_generator + +__all__ = ["YoloV4_COCO", "SimplifiedYoloV4"] + + +class Object(pydantic.BaseModel): + """Object class""" + + class_name: str + confidence: float + x1: int + y1: int + x2: int + y2: int + + +class YoloV4_COCO(Yolov4): + """YoloV4 COCO model wrapper""" + + def __init__( + self, + n_classes: int = 80, + device="cpu", + classes: List[str] = None, + input_image_size: Union[List[int], int] = 608, + batch_size: int = 4, + load_pretrained: bool = True, + fp16: bool = False, + model_path: Union[str, pathlib.Path] = None, + ): + """ + Initialize the YoloV4 COCO. 
+ + Args: + n_classes: number of classes + device: processing device + input_image_size: input image size + batch_size: the number of images that the neural network processes in one run + fp16: use fp16 precision + model_path: path to model weights + load_pretrained: load pretrained weights + """ + if model_path is None: + model_path = yolov4_coco_pretrained() + self.fp16 = fp16 + self.device = device + self.batch_size = batch_size + if isinstance(input_image_size, list): + self.input_image_size = input_image_size[:2] + else: + self.input_image_size = (input_image_size, input_image_size) + + if load_pretrained: + state_dict = torch.load(model_path, map_location="cpu") + self.classes = state_dict["classes"] + super().__init__(n_classes=len(state_dict["classes"]), inference=True) + self.load_state_dict(state_dict["state"]) + else: + self.classes = classes + super().__init__(n_classes=n_classes, inference=True) + + self.to(device) + self.eval() + + def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: + """ + Transform input image to suitable data format for neural network + + Args: + data: input image + + Returns: + input for neural network + + """ + image = data.resize(self.input_image_size) + # noinspection PyTypeChecker + image = np.array(image).astype(np.float32) + image = image.transpose((2, 0, 1)) + image = image / 255.0 + image = torch.from_numpy(image).float() + return torch.unsqueeze(image, 0).type(torch.FloatTensor) + + def data_postprocessing( + self, data: List[torch.FloatTensor], images: List[PIL.Image.Image] + ) -> List[Object]: + """ + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
+ + Args: + data: output data from neural network + images: input images + + + Returns: + list of objects for each image + + """ + output = post_processing(0.4, 0.6, data) + images_objects = [] + for image_idx, image_objects in enumerate(output): + image_size = images[image_idx].size + objects = [] + for obj in image_objects: + objects.append( + Object( + class_name=self.classes[obj[6]], + confidence=obj[5], + x1=int(obj[0] * image_size[0]), + y1=int(obj[1] * image_size[1]), + x2=int(obj[2] * image_size[0]), + y2=int(obj[3] * image_size[1]), + ) + ) + images_objects.append(objects) + + return images_objects + + def __call__( + self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] + ) -> List[List[Object]]: + """ + Passes input images though neural network + + Args: + images: input images + + Returns: + list of objects for each image + + """ + collect_masks = [] + autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) + with autocast: + cast_network(self, dtype) + for image_batch in batch_generator(images, self.batch_size): + converted_images = thread_pool_processing( + lambda x: convert_image(load_image(x)), image_batch + ) + batches = torch.vstack( + thread_pool_processing(self.data_preprocessing, converted_images) + ) + with torch.no_grad(): + batches = Variable(batches).to(self.device) + out = super().__call__(batches) + out_cpu = [out_i.cpu() for out_i in out] + del batches, out + out = self.data_postprocessing(out_cpu, converted_images) + collect_masks += out + + return collect_masks + + +class SimplifiedYoloV4(YoloV4_COCO): + """ + The YoloV4 COCO classifier, but classifies only 7 supercategories. 
+ + human - Scenes of people, such as portrait photographs + animals - Scenes with animals + objects - Scenes with normal objects + cars - Scenes with cars + other - Other scenes + """ + + db = { + "human": ["person"], + "animals": [ + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + ], + "cars": [ + "car", + "motorbike", + "bus", + "truck", + ], + "objects": [ + "bicycle", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "sofa", + "pottedplant", + "bed", + "diningtable", + "toilet", + "tvmonitor", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", + ], + "other": ["aeroplane", "train", "boat"], + } + + def data_postprocessing( + self, data: List[torch.FloatTensor], images: List[PIL.Image.Image] + ) -> List[List[str]]: + """ + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
+ + Args: + data: output data from neural network + images: input images + """ + objects = super().data_postprocessing(data, images) + new_output = [] + + for image_objects in objects: + new_objects = [] + for obj in image_objects: + for key, values in list(self.db.items()): + if obj.class_name in values: + new_objects.append(key) # We don't need bbox at this moment + new_output.append(new_objects) + + return new_output diff --git a/carvekit/pipelines/postprocessing/__init__.py b/carvekit/pipelines/postprocessing/__init__.py new file mode 100644 index 0000000..1de606e --- /dev/null +++ b/carvekit/pipelines/postprocessing/__init__.py @@ -0,0 +1,2 @@ +from carvekit.pipelines.postprocessing.matting import MattingMethod +from carvekit.pipelines.postprocessing.casmatting import CasMattingMethod diff --git a/carvekit/pipelines/postprocessing/casmatting.py b/carvekit/pipelines/postprocessing/casmatting.py new file mode 100644 index 0000000..d8eec79 --- /dev/null +++ b/carvekit/pipelines/postprocessing/casmatting.py @@ -0,0 +1,83 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+License: Apache License 2.0 +""" +from carvekit.ml.wrap.fba_matting import FBAMatting +from carvekit.ml.wrap.cascadepsp import CascadePSP +from typing import Union, List +from PIL import Image +from pathlib import Path +from carvekit.trimap.cv_gen import CV2TrimapGenerator +from carvekit.trimap.generator import TrimapGenerator +from carvekit.utils.mask_utils import apply_mask +from carvekit.utils.pool_utils import thread_pool_processing +from carvekit.utils.image_utils import load_image, convert_image + +__all__ = ["CasMattingMethod"] + + +class CasMattingMethod: + """ + Improve segmentation quality by refining segmentation with the CascadePSP model + and post-processing the segmentation with the FBAMatting model + """ + + def __init__( + self, + refining_module: Union[CascadePSP], + matting_module: Union[FBAMatting], + trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator], + device="cpu", + ): + """ + Initializes CasMattingMethod class. + + Args: + refining_module: Initialized refining network + matting_module: Initialized matting neural network class + trimap_generator: Initialized trimap generator class + device: Processing device used for applying mask to image + """ + self.device = device + self.refining_module = refining_module + self.matting_module = matting_module + self.trimap_generator = trimap_generator + + def __call__( + self, + images: List[Union[str, Path, Image.Image]], + masks: List[Union[str, Path, Image.Image]], + ): + """ + Passes data through apply_mask function + + Args: + images: list of images + masks: list pf masks + + Returns: + list of images + """ + if len(images) != len(masks): + raise ValueError("Images and Masks lists should have same length!") + images = thread_pool_processing(lambda x: convert_image(load_image(x)), images) + masks = thread_pool_processing( + lambda x: convert_image(load_image(x), mode="L"), masks + ) + refined_masks = self.refining_module(images, masks) + trimaps = thread_pool_processing( + lambda x: 
self.trimap_generator( + original_image=images[x], mask=refined_masks[x] + ), + range(len(images)), + ) + alpha = self.matting_module(images=images, trimaps=trimaps) + return list( + map( + lambda x: apply_mask( + image=images[x], mask=alpha[x], device=self.device + ), + range(len(images)), + ) + ) diff --git a/carvekit/pipelines/postprocessing.py b/carvekit/pipelines/postprocessing/matting.py similarity index 89% rename from carvekit/pipelines/postprocessing.py rename to carvekit/pipelines/postprocessing/matting.py index fc22451..cd91142 100644 --- a/carvekit/pipelines/postprocessing.py +++ b/carvekit/pipelines/postprocessing/matting.py @@ -32,9 +32,9 @@ def __init__( Initializes Matting Method class. Args: - matting_module: Initialized matting neural network class - trimap_generator: Initialized trimap generator class - device: Processing device used for applying mask to image + - `matting_module`: Initialized matting neural network class + - `trimap_generator`: Initialized trimap generator class + - `device`: Processing device used for applying mask to image """ self.device = device self.matting_module = matting_module @@ -49,11 +49,11 @@ def __call__( Passes data through apply_mask function Args: - images: list of images - masks: list pf masks + - `images`: list of images + - `masks`: list pf masks Returns: - list of images + list of images """ if len(images) != len(masks): raise ValueError("Images and Masks lists should have same length!") diff --git a/carvekit/pipelines/preprocessing/__init__.py b/carvekit/pipelines/preprocessing/__init__.py new file mode 100644 index 0000000..5355429 --- /dev/null +++ b/carvekit/pipelines/preprocessing/__init__.py @@ -0,0 +1,2 @@ +from carvekit.pipelines.preprocessing.stub import PreprocessingStub +from carvekit.pipelines.preprocessing.autoscene import AutoScene diff --git a/carvekit/pipelines/preprocessing/autoscene.py b/carvekit/pipelines/preprocessing/autoscene.py new file mode 100644 index 0000000..04138fb --- /dev/null 
+++ b/carvekit/pipelines/preprocessing/autoscene.py @@ -0,0 +1,85 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0 +""" +from pathlib import Path + +from PIL import Image +from typing import Union, List + +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 +from carvekit.ml.wrap.u2net import U2NET + +__all__ = ["AutoScene"] + + +class AutoScene: + """AutoScene preprocessing method""" + + def __init__(self, scene_classifier: SceneClassifier): + """ + Args: + scene_classifier: SceneClassifier instance + """ + self.scene_classifier = scene_classifier + + @staticmethod + def select_net(scene: str): + """ + Selects the network to be used for segmentation based on the detected scene + + Args: + scene: scene name + """ + if scene == "hard": + return TracerUniversalB7 + elif scene == "soft": + return U2NET + elif scene == "digital": + return TracerUniversalB7 # TODO: not implemented yet + + def __call__(self, interface, images: List[Union[str, Path, Image.Image]]): + """ + Automatically detects the scene and selects the appropriate network for segmentation + + Args: + interface: Interface instance + images: list of images + + Returns: + list of masks + """ + scene_analysis = self.scene_classifier(images) + images_per_scene = {} + for i, image in enumerate(images): + scene_name = scene_analysis[i][0][0] + if scene_name not in images_per_scene: + images_per_scene[scene_name] = [] + images_per_scene[scene_name].append(image) + + masks_per_scene = {} + for scene_name, igs in list(images_per_scene.items()): + net = self.select_net(scene_name) + if isinstance(interface.segmentation_pipeline, net): + masks_per_scene[scene_name] = interface.segmentation_pipeline(igs) + else: + old_device = interface.segmentation_pipeline.device + interface.segmentation_pipeline.to( + "cpu" + ) # unload 
model from gpu, to avoid OOM + net_instance = net(device=old_device) + masks_per_scene[scene_name] = net_instance(igs) + del net_instance + interface.segmentation_pipeline.to(old_device) # load model back to gpu + + # restore one list of masks with the same order as images + masks = [] + for i, image in enumerate(images): + scene_name = scene_analysis[i][0][0] + masks.append( + masks_per_scene[scene_name][images_per_scene[scene_name].index(image)] + ) + + return masks diff --git a/carvekit/pipelines/preprocessing.py b/carvekit/pipelines/preprocessing/stub.py similarity index 81% rename from carvekit/pipelines/preprocessing.py rename to carvekit/pipelines/preprocessing/stub.py index 3d1e848..ea1b8b9 100644 --- a/carvekit/pipelines/preprocessing.py +++ b/carvekit/pipelines/preprocessing/stub.py @@ -16,11 +16,11 @@ class PreprocessingStub: def __call__(self, interface, images: List[Union[str, Path, Image.Image]]): """ - Passes data though interface.segmentation_pipeline() method + Passes data though `interface.segmentation_pipeline()` method Args: - interface: Interface instance - images: list of images + - `interface`: Interface instance + - `images`: list of images Returns: the result of passing data through segmentation_pipeline method of interface diff --git a/carvekit/trimap/add_ops.py b/carvekit/trimap/add_ops.py index dfb37ca..c1f8313 100644 --- a/carvekit/trimap/add_ops.py +++ b/carvekit/trimap/add_ops.py @@ -13,14 +13,14 @@ def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image: Applies a filter to the mask by the probability of locating an object in the object area. Args: - prob_threshold: Threshold of probability for mark area as background. - mask: Predicted object mask + prob_threshold (int, default=231): Threshold of probability for mark area as background. 
+ mask (Image.Image): Predicted object mask Raises: - ValueError if mask or trimap has wrong color mode + ValueError: if mask or trimap has wrong color mode Returns: - Generated trimap for image. + Image.Image: generated trimap for image. """ if mask.mode != "L": raise ValueError("Input mask has wrong color mode.") @@ -38,15 +38,15 @@ def prob_as_unknown_area( Marks any uncertainty in the seg mask as an unknown region. Args: - prob_threshold: Threshold of probability for mark area as unknown. - trimap: Generated trimap. - mask: Predicted object mask + prob_threshold (int, default=255): Threshold of probability for mark area as unknown. + trimap (Image.Image): Generated trimap. + mask (Image.Image): Predicted object mask Raises: - ValueError if mask or trimap has wrong color mode + ValueError: if mask or trimap has wrong color mode Returns: - Generated trimap for image. + Image.Image: Generated trimap for image. """ if mask.mode != "L" or trimap.mode != "L": raise ValueError("Input mask has wrong color mode.") @@ -63,13 +63,12 @@ def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image: Performs erosion on the mask and marks the resulting area as an unknown region. Args: - erosion_iters: The number of iterations of erosion that + erosion_iters (int, default=1): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area - trimap: Generated trimap. - mask: Predicted object mask + trimap (Image.Image): Generated trimap. Returns: - Generated trimap for image. + Image.Image: Generated trimap for image. 
""" if trimap.mode != "L": raise ValueError("Input mask has wrong color mode.") diff --git a/carvekit/trimap/cv_gen.py b/carvekit/trimap/cv_gen.py index fc2c229..8323751 100644 --- a/carvekit/trimap/cv_gen.py +++ b/carvekit/trimap/cv_gen.py @@ -14,9 +14,9 @@ def __init__(self, kernel_size: int = 30, erosion_iters: int = 1): Initialize a new CV2TrimapGenerator instance Args: - kernel_size: The size of the offset from the object mask + kernel_size (int, default=30): The size of the offset from the object mask in pixels when an unknown area is detected in the trimap - erosion_iters: The number of iterations of erosion that + erosion_iters (int, default=1: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area """ self.kernel_size = kernel_size @@ -30,11 +30,11 @@ def __call__( Based on cv2 erosion algorithm. Args: - original_image: Original image - mask: Predicted object mask + original_image (PIL.Image.Image): Original image + mask (PIL.Image.Image): Predicted object mask Returns: - Generated trimap for image. + PIL.Image.Image: Generated trimap for image. 
""" if mask.mode != "L": raise ValueError("Input mask has wrong color mode.") diff --git a/carvekit/trimap/generator.py b/carvekit/trimap/generator.py index 0656f45..cbabea6 100644 --- a/carvekit/trimap/generator.py +++ b/carvekit/trimap/generator.py @@ -16,11 +16,11 @@ def __init__( Initialize a TrimapGenerator instance Args: - prob_threshold: Probability threshold at which the + prob_threshold (int, default=231): Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied - kernel_size: The size of the offset from the object mask + kernel_size (int, default=30): The size of the offset from the object mask in pixels when an unknown area is detected in the trimap - erosion_iters: The number of iterations of erosion that + erosion_iters (int, default=5): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area """ super().__init__(kernel_size, erosion_iters=0) @@ -31,12 +31,13 @@ def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Imag """ Generates trimap based on predicted object mask to refine object mask borders. Based on cv2 erosion algorithm and additional prob. filters. + Args: - original_image: Original image - mask: Predicted object mask + original_image (Image.Image): Original image + mask (Image.Image): Predicted object mask Returns: - Generated trimap for image. + Image.Image: Generated trimap for image. 
""" filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold) trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask) diff --git a/carvekit/utils/download_models.py b/carvekit/utils/download_models.py index b1b52ad..de13778 100644 --- a/carvekit/utils/download_models.py +++ b/carvekit/utils/download_models.py @@ -45,12 +45,25 @@ "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5", "filename": "tracer_b7.pth", }, - "tracer_hair.pth": { - "repository": "Carve/tracer_b7", - "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5", - "filename": "tracer_b7.pth", # TODO don't forget change this link!! + "scene_classifier.pth": { + "repository": "Carve/scene_classifier", + "revision": "71c8e4c771dd5a20ff0c5c9e3c8f1c9cf8082740", + "filename": "scene_classifier.pth", + }, + "yolov4_coco_with_classes.pth": { + "repository": "Carve/yolov4_coco", + "revision": "e3fc9cd22f86e456d2749d1ae148400f2f950fb3", + "filename": "yolov4_coco_with_classes.pth", + }, + "cascadepsp.pth": { + "repository": "Carve/cascadepsp", + "revision": "3ca1e5e432344b1277bc88d1c6d4265c46cff62f", + "filename": "cascadepsp.pth", }, } +""" +All data needed to build path relative to huggingface.co for model download +""" MODELS_CHECKSUMS = { "basnet.pth": "e409cb709f4abca87cb11bd44a9ad3f909044a917977ab65244b4c94dd33" @@ -63,9 +76,15 @@ "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7", "tracer_b7.pth": "c439c5c12d4d43d5f9be9ec61e68b2e54658a541bccac2577ef5a54fb252b6e8415d41f7e" "c2487033d0c02b4dd08367958e4e62091318111c519f93e2632be7b", - "tracer_hair.pth": "5c2fb9973fc42fa6208920ffa9ac233cc2ea9f770b24b4a96969d3449aed7ac89e6d37e" - "e486a13e63be5499f2df6ccef1109e9e8797d1326207ac89b2f39a7cf", + "scene_classifier.pth": "6d8692510abde453b406a1fea557afdea62fd2a2a2677283a3ecc2" + "341a4895ee99ed65cedcb79b80775db14c3ffcfc0aad2caec1d85140678852039d2d4e76b4", + "yolov4_coco_with_classes.pth": 
"44b6ec2dd35dc3802bf8c512002f76e00e97bfbc86bc7af6de2fafce229a41b4ca" + "12c6f3d7589278c71cd4ddd62df80389b148c19b84fa03216905407a107fff", + "cascadepsp.pth": "3f895f5126d80d6f73186f045557ea7c8eab4dfa3d69a995815bb2c03d564573f36c474f04d7bf0022a27829f583a1a793b036adf801cb423e41a4831b830122", } +""" +Model -> checksum dictionary +""" def sha512_checksum_calc(file: Path) -> str: @@ -73,7 +92,7 @@ def sha512_checksum_calc(file: Path) -> str: Calculates the SHA512 hash digest of a file on fs Args: - file: Path to the file + file (Path): Path to the file Returns: SHA512 hash digest of a file. @@ -86,6 +105,10 @@ def sha512_checksum_calc(file: Path) -> str: class CachedDownloader: + """ + Metaclass for models downloaders. + """ + __metaclass__ = ABCMeta @property @@ -96,9 +119,24 @@ def name(self) -> str: @property @abstractmethod def fallback_downloader(self) -> Optional["CachedDownloader"]: + """ + Property MAY be overriden in subclasses. + Used in case if subclass failed to download model. So preferred downloader SHOULD be placed higher in the hierarchy. + Less preferred downloader SHOULD be provided by this property. + """ pass def download_model(self, file_name: str) -> Path: + """ + Downloads model from the internet and saves it to the cache. + + Behavior: + If model is already downloaded it will be loaded from the cache. + + If model is already downloaded, but checksum is invalid, it will be downloaded again. + + If model download failed, fallback downloader will be used. + """ try: return self.download_model_base(file_name) except BaseException as e: @@ -116,14 +154,23 @@ def download_model(self, file_name: str) -> Path: raise e @abstractmethod - def download_model_base(self, file_name: str) -> Path: - """Download model from any source if not cached. Returns path if cached""" + def download_model_base(self, model_name: str) -> Path: + """ + Download model from any source if not cached. + Returns: + pathlib.Path: Path to the downloaded model. 
+ """ - def __call__(self, file_name: str): - return self.download_model(file_name) + def __call__(self, model_name: str): + return self.download_model(model_name) class HuggingFaceCompatibleDownloader(CachedDownloader, ABC): + """ + Downloader for models from HuggingFace Hub. + Private models are not supported. + """ + def __init__( self, name: str = "Huggingface.co", @@ -131,7 +178,10 @@ def __init__( fb_downloader: Optional["CachedDownloader"] = None, ): self.cache_dir = checkpoints_dir + """SHOULD be same for all instances to prevent downloading same model multiple times + Points to ~/.cache/carvekit/checkpoints""" self.base_url = base_url + """MUST be a base url with protocol and domain name to huggingface or another, compatible in terms of models downloading API source""" self._name = name self._fallback_downloader = fb_downloader @@ -143,13 +193,18 @@ def fallback_downloader(self) -> Optional["CachedDownloader"]: def name(self): return self._name - def check_for_existence(self, file_name: str) -> Optional[Path]: - if file_name not in MODELS_URLS.keys(): + def check_for_existence(self, model_name: str) -> Optional[Path]: + """ + Checks if model is already downloaded and cached. Verifies file integrity by checksum. + Returns: + Optional[pathlib.Path]: Path to the cached model if cached. 
+ """ + if model_name not in MODELS_URLS.keys(): raise FileNotFoundError("Unknown model!") path = ( self.cache_dir - / MODELS_URLS[file_name]["repository"].split("/")[1] - / file_name + / MODELS_URLS[model_name]["repository"].split("/")[1] + / model_name ) if not path.exists(): @@ -163,18 +218,18 @@ def check_for_existence(self, file_name: str) -> Optional[Path]: return None return path - def download_model_base(self, file_name: str) -> Path: - cached_path = self.check_for_existence(file_name) + def download_model_base(self, model_name: str) -> Path: + cached_path = self.check_for_existence(model_name) if cached_path is not None: return cached_path else: cached_path = ( self.cache_dir - / MODELS_URLS[file_name]["repository"].split("/")[1] - / file_name + / MODELS_URLS[model_name]["repository"].split("/")[1] + / model_name ) cached_path.parent.mkdir(parents=True, exist_ok=True) - url = MODELS_URLS[file_name] + url = MODELS_URLS[model_name] hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}" try: @@ -190,10 +245,10 @@ def download_model_base(self, file_name: str) -> Path: f.write(chunk) else: if r.status_code == 404: - raise FileNotFoundError(f"Model {file_name} not found!") + raise FileNotFoundError(f"Model {model_name} not found!") else: raise ConnectionError( - f"Error {r.status_code} while downloading model {file_name}!" + f"Error {r.status_code} while downloading model {model_name}!" ) except BaseException as e: if cached_path.exists(): diff --git a/carvekit/utils/fs_utils.py b/carvekit/utils/fs_utils.py index bd6291e..5219af5 100644 --- a/carvekit/utils/fs_utils.py +++ b/carvekit/utils/fs_utils.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+ License: Apache License 2.0 """ from pathlib import Path @@ -14,9 +16,9 @@ def save_file(output: Optional[Path], input_path: Path, image: Image.Image): Saves an image to the file system Args: - output: Output path [dir or end file] - input_path: Input path of the image - image: Image to be saved. + output (Optional[pathlib.Path]): Output path [dir or end file] + input_path (pathlib.Path): Input path of the image + image (Image.Image): Image to be saved. """ if isinstance(output, Path) and str(output) != "none": if output.is_dir() and output.exists(): diff --git a/carvekit/utils/image_utils.py b/carvekit/utils/image_utils.py index 8b939f5..cb2a538 100644 --- a/carvekit/utils/image_utils.py +++ b/carvekit/utils/image_utils.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. + License: Apache License 2.0 """ @@ -19,22 +21,22 @@ def to_tensor(x: Any) -> torch.Tensor: Returns a PIL.Image.Image as torch tensor without swap tensor dims. 
Args: - x: PIL.Image.Image instance + x (PIL.Image.Image): image Returns: - torch.Tensor instance + torch.Tensor: image as torch tensor """ return torch.tensor(np.array(x, copy=True)) def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image: - """Returns a PIL.Image.Image class by string path or pathlib path or PIL.Image.Image instance + """Returns a `PIL.Image.Image` class by string path or `pathlib.Path` or `PIL.Image.Image` instance Args: - file: File path or PIL.Image.Image instance + file (Union[str, pathlib.Path, PIL.Image.Image]): File path or `PIL.Image.Image` instance Returns: - PIL.Image.Image instance + PIL.Image.Image: image instance loaded from `file` location Raises: ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image @@ -54,11 +56,11 @@ def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image: """Performs image conversion to correct color mode Args: - image: PIL.Image.Image instance - mode: Colort Mode to convert + image (PIL.Image.Image): `PIL.Image.Image` instance + mode (str, default=RGB): Color mode to convert Returns: - PIL.Image.Image instance + PIL.Image.Image: converted image Raises: ValueError: If image hasn't convertable color mode, or it is too small @@ -71,10 +73,10 @@ def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool: """This function performs image validation. Args: - image: Path to the image or PIL.Image.Image instance being checked. + image (Union[pathlib.Path, PIL.Image.Image]): Path to the image or `PIL.Image.Image` instance being checked. Returns: - True if image is valid + bool: True if image is valid, False otherwise. 
Raises: ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small @@ -92,7 +94,11 @@ def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool: elif isinstance(image, PIL.Image.Image): if not (image.size[0] > 32 and image.size[1] > 32): raise ValueError("Image should be bigger then (32x32) pixels.") - elif image.mode not in ["RGB", "RGBA", "L"]: + elif image.mode not in [ + "RGB", + "RGBA", + "L", + ]: raise ValueError("Wrong image color mode.") else: raise ValueError("Unknown input file type") @@ -106,12 +112,12 @@ def transparency_paste( Inserts an image into another image while maintaining transparency. Args: - bg_img: background image - fg_img: foreground image - box: place to paste + bg_img (PIL.Image.Image): background image + fg_img (PIL.Image.Image): foreground image + box (tuple[int, int]): place to paste Returns: - Background image with pasted foreground image at point or in the specified box + PIL.Image.Image: Background image with pasted foreground image at point or in the specified box """ fg_img_trans = PIL.Image.new("RGBA", bg_img.size) fg_img_trans.paste(fg_img, box, mask=fg_img) @@ -131,15 +137,15 @@ def add_margin( Adds margin to the image. Args: - pil_img: Image that needed to add margin. - top: pixels count at top side - right: pixels count at right side - bottom: pixels count at bottom side - left: pixels count at left side - color: color of margin + pil_img (PIL.Image.Image): Image that needed to add margin. + top (int): pixels count at top side + right (int): pixels count at right side + bottom (int): pixels count at bottom side + left (int): pixels count at left side + color (Tuple[int, int, int, int]): color of margin Returns: - Image with margin. + PIL.Image.Image: Image with margin. 
""" width, height = pil_img.size new_width = width + right + left diff --git a/carvekit/utils/mask_utils.py b/carvekit/utils/mask_utils.py index 4402036..cd712c8 100644 --- a/carvekit/utils/mask_utils.py +++ b/carvekit/utils/mask_utils.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. + License: Apache License 2.0 """ import PIL.Image @@ -19,13 +21,13 @@ def composite( https://pymatting.github.io/intro.html#alpha-matting math formula. Args: - device: Processing device - foreground: Image that will be pasted to background image with following alpha mask. - background: Background image - alpha: Alpha Image + foreground (PIL.Image.Image): Image that will be pasted to background image with following alpha mask. + background (PIL.Image.Image): Background image + alpha (PIL.Image.Image): Alpha Image + device (Literal[cpu, cuda]): Processing device Returns: - Composited image as PIL.Image instance. + PIL.Image.Image: Composited image. """ foreground = foreground.convert("RGBA") @@ -58,12 +60,12 @@ def apply_mask( Applies mask to foreground. Args: - device: Processing device. - image: Image with background. - mask: Alpha Channel mask for this image. + image (PIL.Image.Image): Image with background. + mask (PIL.Image.Image): Alpha Channel mask for this image. + device (Literal[cpu, cuda]): Processing device. Returns: - Image without background, where mask was black. + PIL.Image.Image: Image without background, where mask was black. 
""" background = PIL.Image.new("RGBA", image.size, color=(130, 130, 130, 0)) return composite(image, background, mask, device=device).convert("RGBA") @@ -77,7 +79,7 @@ def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image: image: RGBA PIL image Returns: - RGBA alpha channel image + PIL.Image.Image: RGBA alpha channel image """ alpha = image.split()[-1] bg = PIL.Image.new("RGBA", image.size, (0, 0, 0, 255)) diff --git a/carvekit/utils/models_utils.py b/carvekit/utils/models_utils.py index da0141d..cdd5329 100644 --- a/carvekit/utils/models_utils.py +++ b/carvekit/utils/models_utils.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. + License: Apache License 2.0 """ @@ -14,7 +16,7 @@ class EmptyAutocast(object): """ - Empty class for disable any autocasting. + Empty class for any auto-casting disabling. """ def __enter__(self): @@ -34,20 +36,21 @@ def get_precision_autocast( Tuple[autocast, Union[torch.dtype, Any]], ]: """ - Returns precision and autocast settings for given device and fp16 settings. + Returns precision and auto-cast settings for given device and fp16 settings. + Args: - device: Device to get precision and autocast settings for. - fp16: Whether to use fp16 precision. - override_dtype: Override dtype for autocast. + device (Literal[cpu, cuda]): Device to get precision and auto-cast settings for. + fp16 (bool): Whether to use fp16 precision. + override_dtype (bool): Override dtype for auto-cast. Returns: - Autocast object, dtype + Union[Tuple[EmptyAutocast, Union[torch.dtype, Any]],Tuple[autocast, Union[torch.dtype, Any]]]: Autocast object, dtype """ dtype = torch.float32 cache_enabled = None if device == "cpu" and fp16: - warnings.warn('FP16 is not supported on CPU. Using FP32 instead.') + warnings.warn("FP16 is not supported on CPU. Using FP32 instead.") dtype = torch.float32 # TODO: Implement BFP16 on CPU. 
There are unexpected slowdowns on cpu on a clean environment. @@ -59,7 +62,6 @@ def get_precision_autocast( # torch.bfloat16 # ) # Using bfloat16 for CPU, since autocast is not supported for float16 - if "cuda" in device and fp16: dtype = torch.float16 cache_enabled = True @@ -79,11 +81,12 @@ def get_precision_autocast( def cast_network(network: torch.nn.Module, dtype: torch.dtype): - """Cast network to given dtype + """ + Cast network to given dtype Args: - network: Network to be casted - dtype: Dtype to cast network to + network (torch.nn.Module): Network to be casted + dtype (torch.dtype): Dtype to cast network to """ if dtype == torch.float16: network.half() @@ -95,11 +98,12 @@ def cast_network(network: torch.nn.Module, dtype: torch.dtype): raise ValueError(f"Unknown dtype {dtype}") -def fix_seed(seed=42): - """Sets fixed random seed +def fix_seed(seed: int = 42): + """ + Sets fixed random seed Args: - seed: Random seed to be set + seed (int, default=42): Random seed to be set """ random.seed(seed) torch.manual_seed(seed) diff --git a/carvekit/utils/pool_utils.py b/carvekit/utils/pool_utils.py index ae3b741..8822ea9 100644 --- a/carvekit/utils/pool_utils.py +++ b/carvekit/utils/pool_utils.py @@ -1,39 +1,41 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. + License: Apache License 2.0 """ from concurrent.futures import ThreadPoolExecutor -from typing import Any, Iterable +from typing import Any, Iterable, Callable, Collection, List -def thread_pool_processing(func: Any, data: Iterable, workers=18): +def thread_pool_processing(func: Callable[[Any], Any], data: Iterable, workers=18): """ Passes all iterator data through the given function Args: - workers: Count of workers. - func: function to pass data through - data: input iterator + workers (int, default=18): Count of workers. 
+ func (Callable[[Any], Any]): function to pass data through + data (Iterable): input iterator Returns: - function return list + List[Any]: list of results """ with ThreadPoolExecutor(workers) as p: return list(p.map(func, data)) -def batch_generator(iterable, n=1): +def batch_generator(iterable: Collection, n: int = 1) -> Iterable[Collection]: """ Splits any iterable into n-size packets Args: - iterable: iterator - n: size of packets + iterable (Collection): iterator + n (int, default=1): size of packets Returns: - new n-size packet + Iterable[Collection]: new n-size packet """ it = len(iterable) for ndx in range(0, it, n): diff --git a/carvekit/web/schemas/config.py b/carvekit/web/schemas/config.py index 5d47ffc..8b12c02 100644 --- a/carvekit/web/schemas/config.py +++ b/carvekit/web/schemas/config.py @@ -24,20 +24,26 @@ class MLConfig(BaseModel): "u2net", "deeplabv3", "basnet", "tracer_b7" ] = "tracer_b7" """Segmentation Network""" - preprocessing_method: Literal["none", "stub"] = "none" + preprocessing_method: Literal["none", "stub", "autoscene", "auto"] = "autoscene" """Pre-processing Method""" - postprocessing_method: Literal["fba", "none"] = "fba" + postprocessing_method: Literal["fba", "cascade_fba", "none"] = "cascade_fba" """Post-Processing Network""" device: str = "cpu" """Processing device""" + batch_size_pre: int = 5 + """Batch size for preprocessing method""" batch_size_seg: int = 5 """Batch size for segmentation network""" batch_size_matting: int = 1 """Batch size for matting network""" + batch_size_refine: int = 1 + """Batch size for refine network""" seg_mask_size: int = 640 """The size of the input image for the segmentation neural network.""" matting_mask_size: int = 2048 """The size of the input image for the matting neural network.""" + refine_mask_size: int = 900 + """The size of the input image for the refine neural network.""" fp16: bool = False """Use half precision for inference""" trimap_dilation: int = 30 diff --git 
a/carvekit/web/utils/init_utils.py b/carvekit/web/utils/init_utils.py index f687182..d975e27 100644 --- a/carvekit/web/utils/init_utils.py +++ b/carvekit/web/utils/init_utils.py @@ -1,18 +1,26 @@ +import warnings from os import getenv from typing import Union from loguru import logger +from carvekit.ml.wrap.cascadepsp import CascadePSP +from carvekit.ml.wrap.scene_classifier import SceneClassifier from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig + from carvekit.api.interface import Interface +from carvekit.api.autointerface import AutoInterface + from carvekit.ml.wrap.fba_matting import FBAMatting from carvekit.ml.wrap.u2net import U2NET from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 from carvekit.ml.wrap.basnet import BASNET from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 +from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4 -from carvekit.pipelines.postprocessing import MattingMethod -from carvekit.pipelines.preprocessing import PreprocessingStub + +from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod +from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene from carvekit.trimap.generator import TrimapGenerator @@ -36,6 +44,9 @@ def init_config() -> WebAPIConfig: default_config.ml.postprocessing_method, ), device=getenv("CARVEKIT_DEVICE", default_config.ml.device), + batch_size_pre=int( + getenv("CARVEKIT_BATCH_SIZE_PRE", default_config.ml.batch_size_pre) + ), batch_size_seg=int( getenv("CARVEKIT_BATCH_SIZE_SEG", default_config.ml.batch_size_seg) ), @@ -45,6 +56,12 @@ def init_config() -> WebAPIConfig: default_config.ml.batch_size_matting, ) ), + batch_size_refine=int( + getenv( + "CARVEKIT_BATCH_SIZE_REFINE", + default_config.ml.batch_size_refine, + ) + ), seg_mask_size=int( getenv("CARVEKIT_SEG_MASK_SIZE", default_config.ml.seg_mask_size) ), @@ -54,6 +71,12 @@ def init_config() -> WebAPIConfig: default_config.ml.matting_mask_size, ) ), + refine_mask_size=int( + getenv( + 
"CARVEKIT_REFINE_MASK_SIZE", + default_config.ml.refine_mask_size, + ) + ), fp16=bool(int(getenv("CARVEKIT_FP16", default_config.ml.fp16))), trimap_prob_threshold=int( getenv( @@ -92,74 +115,131 @@ def init_config() -> WebAPIConfig: def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface: if isinstance(config, WebAPIConfig): config = config.ml - if config.segmentation_network == "u2net": - seg_net = U2NET( - device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size, - fp16=config.fp16, - ) - elif config.segmentation_network == "deeplabv3": - seg_net = DeepLabV3( - device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size, - fp16=config.fp16, + if config.preprocessing_method == "auto": + warnings.warn( + "Preprocessing_method is set to `auto`." + "We will use automatic methods to determine the best methods for your images! " + "Please note that this is not always the best option and all other options will be ignored!" 
) - elif config.segmentation_network == "basnet": - seg_net = BASNET( - device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size, - fp16=config.fp16, + scene_classifier = SceneClassifier( + device=config.device, batch_size=config.batch_size_pre, fp16=config.fp16 ) - elif config.segmentation_network == "tracer_b7": - seg_net = TracerUniversalB7( - device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size, - fp16=config.fp16, + object_classifier = SimplifiedYoloV4( + device=config.device, batch_size=config.batch_size_pre, fp16=config.fp16 ) - else: - seg_net = TracerUniversalB7( - device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size, + return AutoInterface( + scene_classifier=scene_classifier, + object_classifier=object_classifier, + segmentation_batch_size=config.batch_size_seg, + postprocessing_batch_size=config.batch_size_matting, + postprocessing_image_size=config.matting_mask_size, + segmentation_device=config.device, + postprocessing_device=config.device, fp16=config.fp16, ) - if config.preprocessing_method == "stub": - preprocessing = PreprocessingStub() - elif config.preprocessing_method == "none": - preprocessing = None else: - preprocessing = None + if config.segmentation_network == "u2net": + seg_net = U2NET( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) + elif config.segmentation_network == "deeplabv3": + seg_net = DeepLabV3( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) + elif config.segmentation_network == "basnet": + seg_net = BASNET( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) + elif config.segmentation_network == "tracer_b7": + seg_net = TracerUniversalB7( + device=config.device, + 
batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) + else: + seg_net = TracerUniversalB7( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) - if config.postprocessing_method == "fba": - fba = FBAMatting( - device=config.device, - batch_size=config.batch_size_matting, - input_tensor_size=config.matting_mask_size, - fp16=config.fp16, - ) - trimap_generator = TrimapGenerator( - prob_threshold=config.trimap_prob_threshold, - kernel_size=config.trimap_dilation, - erosion_iters=config.trimap_erosion, - ) - postprocessing = MattingMethod( - device=config.device, matting_module=fba, trimap_generator=trimap_generator - ) + if config.preprocessing_method == "stub": + preprocessing = PreprocessingStub() + elif config.preprocessing_method == "none": + preprocessing = None + elif config.preprocessing_method == "autoscene": + preprocessing = AutoScene( + scene_classifier=SceneClassifier( + device=config.device, + batch_size=config.batch_size_pre, + fp16=config.fp16, + ) + ) + else: + preprocessing = None - elif config.postprocessing_method == "none": - postprocessing = None - else: - postprocessing = None + if config.postprocessing_method == "fba": + fba = FBAMatting( + device=config.device, + batch_size=config.batch_size_matting, + input_tensor_size=config.matting_mask_size, + fp16=config.fp16, + ) + trimap_generator = TrimapGenerator( + prob_threshold=config.trimap_prob_threshold, + kernel_size=config.trimap_dilation, + erosion_iters=config.trimap_erosion, + ) + postprocessing = MattingMethod( + device=config.device, + matting_module=fba, + trimap_generator=trimap_generator, + ) + elif config.postprocessing_method == "cascade_fba": + cascadepsp = CascadePSP( + device=config.device, + batch_size=config.batch_size_refine, + input_tensor_size=config.refine_mask_size, + fp16=config.fp16, + ) + fba = FBAMatting( + device=config.device, + 
batch_size=config.batch_size_matting, + input_tensor_size=config.matting_mask_size, + fp16=config.fp16, + ) + trimap_generator = TrimapGenerator( + prob_threshold=config.trimap_prob_threshold, + kernel_size=config.trimap_dilation, + erosion_iters=config.trimap_erosion, + ) + postprocessing = CasMattingMethod( + device=config.device, + matting_module=fba, + trimap_generator=trimap_generator, + refining_module=cascadepsp, + ) + elif config.postprocessing_method == "none": + postprocessing = None + else: + postprocessing = None - interface = Interface( - pre_pipe=preprocessing, - post_pipe=postprocessing, - seg_pipe=seg_net, - device=config.device, - ) + interface = Interface( + pre_pipe=preprocessing, + post_pipe=postprocessing, + seg_pipe=seg_net, + device=config.device, + ) return interface diff --git a/conftest.py b/conftest.py index f328d35..3f75d22 100644 --- a/conftest.py +++ b/conftest.py @@ -23,6 +23,7 @@ from carvekit.ml.wrap.fba_matting import FBAMatting from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 +from carvekit.ml.wrap.scene_classifier import SceneClassifier @pytest.fixture() @@ -37,6 +38,15 @@ def u2net_model() -> Callable[[bool], U2NET]: ) +@pytest.fixture() +def scene_classifier_model() -> Callable[[bool], SceneClassifier]: + return lambda fb16: SceneClassifier( + device="cuda" if torch.cuda.is_available() else "cpu", + batch_size=5, + fp16=fb16, + ) + + @pytest.fixture() def tracer_model() -> Callable[[bool], TracerUniversalB7]: return lambda fb16: TracerUniversalB7( diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml index 1fe3f5a..bcaa647 100644 --- a/docker-compose.cpu.yml +++ b/docker-compose.cpu.yml @@ -7,13 +7,16 @@ services: - CARVEKIT_PORT=5000 - CARVEKIT_HOST=0.0.0.0 - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 - - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub - - CARVEKIT_POSTPROCESSING_METHOD=fba # can be none, fba + - 
CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub, autoscene, auto + - CARVEKIT_POSTPROCESSING_METHOD=cascade_fba # can be none, fba, cascade_fba - CARVEKIT_DEVICE=cpu # can be cuda (req. cuda docker image), cpu + - CARVEKIT_BATCH_SIZE_PRE=5 # Number of images processed per one preprocessing method call. NOT USED IF WEB API IS USED - CARVEKIT_BATCH_SIZE_SEG=5 # Number of images processed per one segmentation nn call. NOT USED IF WEB API IS USED - CARVEKIT_BATCH_SIZE_MATTING=1 # Number of images processed per one matting nn call. NOT USED IF WEB API IS USED + - CARVEKIT_BATCH_SIZE_REFINE=1 # Number of images processed per one refine nn call. NOT USED IF WEB API IS USED - CARVEKIT_SEG_MASK_SIZE=640 # The size of the input image for the segmentation neural network. - CARVEKIT_MATTING_MASK_SIZE=2048 # The size of the input image for the matting neural network. + - CARVEKIT_REFINE_MASK_SIZE=900 # The size of the input image for the refine neural network. - CARVEKIT_FP16=0 # Enables FP16 mode (Only CUDA at the moment) - CARVEKIT_TRIMAP_PROB_THRESHOLD=231 # Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied - CARVEKIT_TRIMAP_DILATION=30 # The size of the offset radius from the object mask in pixels when forming an unknown area diff --git a/docker-compose.cuda.yml b/docker-compose.cuda.yml index 8308594..f90d9a2 100644 --- a/docker-compose.cuda.yml +++ b/docker-compose.cuda.yml @@ -7,13 +7,16 @@ services: - CARVEKIT_PORT=5000 - CARVEKIT_HOST=0.0.0.0 - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 - - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub - - CARVEKIT_POSTPROCESSING_METHOD=fba # can be none, fba + - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub, autoscene, auto + - CARVEKIT_POSTPROCESSING_METHOD=cascade_fba # can be none, fba, cascade_fba - CARVEKIT_DEVICE=cuda # can be cuda (req. 
cuda docker image), cpu + - CARVEKIT_BATCH_SIZE_PRE=5 # Number of images processed per one preprocessing method call. NOT USED IF WEB API IS USED - CARVEKIT_BATCH_SIZE_SEG=5 # Number of images processed per one segmentation nn call. NOT USED IF WEB API IS USED - CARVEKIT_BATCH_SIZE_MATTING=1 # Number of images processed per one matting nn call. NOT USED IF WEB API IS USED + - CARVEKIT_BATCH_SIZE_REFINE=1 # Number of images processed per one refine nn call. NOT USED IF WEB API IS USED - CARVEKIT_SEG_MASK_SIZE=640 # The size of the input image for the segmentation neural network. - CARVEKIT_MATTING_MASK_SIZE=2048 # The size of the input image for the matting neural network. + - CARVEKIT_REFINE_MASK_SIZE=900 # The size of the input image for the refine neural network. - CARVEKIT_FP16=0 # Enables FP16 mode (Only CUDA at the moment) - CARVEKIT_TRIMAP_PROB_THRESHOLD=231 # Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied - CARVEKIT_TRIMAP_DILATION=30 # The size of the offset radius from the object mask in pixels when forming an unknown area diff --git a/docs/CREDITS.md b/docs/CREDITS.md index c544c65..337f9d0 100644 --- a/docs/CREDITS.md +++ b/docs/CREDITS.md @@ -24,3 +24,5 @@ All images are copyrighted by their authors. 10. https://arxiv.org/abs/1703.06870 11. https://github.com/Karel911/TRACER 12. https://arxiv.org/abs/2112.07380 +13. https://github.com/hkchengrex/CascadePSP + diff --git a/docs/api/__init__.html b/docs/api/__init__.html new file mode 100644 index 0000000..bbcbd85 --- /dev/null +++ b/docs/api/__init__.html @@ -0,0 +1,48 @@ + + + + + + +__init__ API documentation + + + + + + + + + + + +
+
+
+

Module __init__

+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/basnet.html b/docs/api/basnet.html new file mode 100644 index 0000000..acb0232 --- /dev/null +++ b/docs/api/basnet.html @@ -0,0 +1,469 @@ + + + + + + +basnet API documentation + + + + + + + + + + + +
+
+
+

Module basnet

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import pathlib
+from typing import Union, List
+
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+
+from carvekit.ml.arch.basnet.basnet import BASNet
+from carvekit.ml.files.models_loc import basnet_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["BASNET"]
+
+
+class BASNET(BASNet):
+    """BASNet model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the BASNET model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=320): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision **not supported at this moment**
+        """
+        super(BASNET, self).__init__(n_channels=3, n_classes=1)
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(basnet_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=np.float64)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for use with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(
+                    batches
+                )
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, d8, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BASNET +(device='cpu', input_image_size: Union[List[int], int] = 320, batch_size: int = 10, load_pretrained: bool = True, fp16: bool = False) +
+
+

BASNet model interface

+

Initialize the BASNET model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_image_size : Union[List[int], int], default=320
+
input image size
+
batch_size : int, default=10
+
the number of images that the neural network processes in one run
+
load_pretrained : bool, default=True
+
loading pretrained model
+
fp16 : bool, default=False
+
use fp16 precision (not supported at this moment)
+
+
+ +Expand source code + +
class BASNET(BASNet):
+    """BASNet model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the BASNET model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=320): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision **not supported at this moment**
+        """
+        super(BASNET, self).__init__(n_channels=3, n_classes=1)
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(basnet_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=np.float64)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(
+                    batches
+                )
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, d8, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+

Ancestors

+
    +
  • BASNet
  • +
  • torch.nn.modules.module.Module
  • +
+

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, original_image:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask as PIL Image instance
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+    """
+    data = data.unsqueeze(0)
+    mask = data[:, 0, :, :]
+    ma = torch.max(mask)  # Normalizes prediction
+    mi = torch.min(mask)
+    predict = ((mask - mi) / (ma - mi)).squeeze()
+    predict_np = predict.cpu().data.numpy() * 255
+    mask = Image.fromarray(predict_np).convert("L")
+    mask = mask.resize(original_image.size, resample=3)
+    return mask
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.Tensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.Tensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.Tensor: input for neural network
+
+    """
+    resized = data.resize(self.input_image_size)
+    # noinspection PyTypeChecker
+    resized_arr = np.array(resized, dtype=np.float64)
+    temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+    if np.max(resized_arr) != 0:
+        resized_arr /= np.max(resized_arr)
+    temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+    temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+    temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+    temp_image = temp_image.transpose((2, 0, 1))
+    temp_image = np.expand_dims(temp_image, 0)
+    return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/api/autointerface.html b/docs/api/carvekit/api/autointerface.html new file mode 100644 index 0000000..702f700 --- /dev/null +++ b/docs/api/carvekit/api/autointerface.html @@ -0,0 +1,696 @@ + + + + + + +carvekit.api.autointerface API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.api.autointerface

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from collections import Counter
+from pathlib import Path
+
+from PIL import Image
+from typing import Union, List, Dict
+
+from carvekit.api.interface import Interface
+from carvekit.ml.wrap.basnet import BASNET
+from carvekit.ml.wrap.cascadepsp import CascadePSP
+from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from carvekit.ml.wrap.scene_classifier import SceneClassifier
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4
+from carvekit.pipelines.postprocessing import CasMattingMethod
+from carvekit.trimap.generator import TrimapGenerator
+
+__all__ = ["AutoInterface"]
+
+from carvekit.utils.image_utils import load_image
+
+from carvekit.utils.pool_utils import thread_pool_processing
+
+
+class AutoInterface(Interface):
+    def __init__(
+        self,
+        scene_classifier: SceneClassifier,
+        object_classifier: SimplifiedYoloV4,
+        segmentation_batch_size: int = 3,
+        refining_batch_size: int = 1,
+        refining_image_size: int = 900,
+        postprocessing_batch_size: int = 1,
+        postprocessing_image_size: int = 2048,
+        segmentation_device: str = "cpu",
+        postprocessing_device: str = "cpu",
+        fp16=False,
+    ):
+        """
+        Args:
+            scene_classifier: SceneClassifier instance
+            object_classifier: SimplifiedYoloV4 instance
+        """
+        self.scene_classifier = scene_classifier
+        self.object_classifier = object_classifier
+        self.segmentation_batch_size = segmentation_batch_size
+        self.refining_batch_size = refining_batch_size
+        self.refining_image_size = refining_image_size
+        self.postprocessing_batch_size = postprocessing_batch_size
+        self.postprocessing_image_size = postprocessing_image_size
+        self.segmentation_device = segmentation_device
+        self.postprocessing_device = postprocessing_device
+        self.fp16 = fp16
+        super().__init__(
+            seg_pipe=None, post_pipe=None, pre_pipe=None
+        )  # just for compatibility with Interface class
+
+    @staticmethod
+    def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]):
+        """
+        Selects the trimap generator parameters depending on the network
+
+        Args:
+            net: network
+        """
+        if net == TracerUniversalB7:
+            return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+        elif net == U2NET:
+            return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+        elif net == DeepLabV3:
+            return {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20}
+        elif net == BASNET:
+            return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+        else:
+            raise ValueError("Unknown network type")
+
+    def select_net(self, scene: str, images_info: List[dict]):
+        # TODO: Update this function, when new networks will be added
+        if scene == "hard":
+            for image_info in images_info:
+                objects = image_info["objects"]
+                if len(objects) == 0:
+                    image_info[
+                        "net"
+                    ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+                    continue
+                obj_counter: Dict = dict(Counter([obj for obj in objects]))
+                # fill empty classes
+                for _tag in self.object_classifier.db:
+                    if _tag not in obj_counter:
+                        obj_counter[_tag] = 0
+
+                non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0]
+
+                if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
+                    # Human only case. Hard Scene? It may be a photo of a person in far/middle distance.
+                    image_info["net"] = TracerUniversalB7
+                    # TODO: will use DeepLabV3+ for this image, it is more suitable for this case,
+                    #  but needs checks for small bbox
+                elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
+                    # Okay, we have a human without extra hairs and something else. Hard border
+                    image_info["net"] = TracerUniversalB7
+                elif obj_counter["cars"] > 0:
+                    # Cars case
+                    image_info["net"] = TracerUniversalB7
+                elif obj_counter["animals"] > 0:
+                    # Animals case
+                    image_info["net"] = U2NET  # animals should be always in soft scenes
+                else:
+                    # We have no idea what is in the image, so we will try to process it with universal model
+                    image_info["net"] = TracerUniversalB7
+
+        elif scene == "soft":
+            for image_info in images_info:
+                objects = image_info["objects"]
+                if len(objects) == 0:
+                    image_info[
+                        "net"
+                    ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+                    continue
+                obj_counter: Dict = dict(Counter([obj for obj in objects]))
+                # fill empty classes
+                for _tag in self.object_classifier.db:
+                    if _tag not in obj_counter:
+                        obj_counter[_tag] = 0
+
+                non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0]
+
+                if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
+                    # Human only case. It may be a portrait
+                    image_info["net"] = U2NET
+                elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
+                    # Okay, we have a human with hairs and something else
+                    image_info["net"] = U2NET
+                elif obj_counter["cars"] > 0:
+                    # Cars case.
+                    image_info["net"] = TracerUniversalB7
+                elif obj_counter["animals"] > 0:
+                    # Animals case
+                    image_info["net"] = U2NET  # animals should be always in soft scenes
+                else:
+                    # We have no idea what is in the image, so we will try to process it with universal model
+                    image_info["net"] = TracerUniversalB7
+        elif scene == "digital":
+            for image_info in images_info:  # TODO: not implemented yet
+                image_info[
+                    "net"
+                ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+
+    def __call__(self, images: List[Union[str, Path, Image.Image]]):
+        """
+        Automatically detects the scene and selects the appropriate network for segmentation
+
+        Args:
+            interface: Interface instance
+            images: list of images
+
+        Returns:
+            list of masks
+        """
+        loaded_images = thread_pool_processing(load_image, images)
+
+        scene_analysis = self.scene_classifier(loaded_images)
+        images_objects = self.object_classifier(loaded_images)
+
+        images_per_scene = {}
+        for i, image in enumerate(loaded_images):
+            scene_name = scene_analysis[i][0][0]
+            if scene_name not in images_per_scene:
+                images_per_scene[scene_name] = []
+            images_per_scene[scene_name].append(
+                {"image": image, "objects": images_objects[i]}
+            )
+
+        for scene_name, images_info in list(images_per_scene.items()):
+            self.select_net(scene_name, images_info)
+
+        # groups images by net
+        for scene_name, images_info in list(images_per_scene.items()):
+            groups = {}
+            for image_info in images_info:
+                net = image_info["net"]
+                if net not in groups:
+                    groups[net] = []
+                groups[net].append(image_info)
+            for net, gimages_info in list(groups.items()):
+                sc_images = [image_info["image"] for image_info in gimages_info]
+                masks = net(
+                    device=self.segmentation_device,
+                    batch_size=self.segmentation_batch_size,
+                    fp16=self.fp16,
+                )(sc_images)
+
+                for i, image_info in enumerate(gimages_info):
+                    image_info["mask"] = masks[i]
+
+        cascadepsp = CascadePSP(
+            device=self.postprocessing_device,
+            fp16=self.fp16,
+            input_tensor_size=self.refining_image_size,
+            batch_size=self.refining_batch_size,
+        )
+
+        fba = FBAMatting(
+            device=self.postprocessing_device,
+            batch_size=self.postprocessing_batch_size,
+            input_tensor_size=self.postprocessing_image_size,
+            fp16=self.fp16,
+        )
+        # groups images by net
+        for scene_name, images_info in list(images_per_scene.items()):
+            groups = {}
+            for image_info in images_info:
+                net = image_info["net"]
+                if net not in groups:
+                    groups[net] = []
+                groups[net].append(image_info)
+            for net, gimages_info in list(groups.items()):
+                sc_images = [image_info["image"] for image_info in gimages_info]
+                # noinspection PyArgumentList
+                trimap_generator = TrimapGenerator(**self.select_params_for_net(net))
+                matting_method = CasMattingMethod(
+                    refining_module=cascadepsp,
+                    matting_module=fba,
+                    trimap_generator=trimap_generator,
+                    device=self.postprocessing_device,
+                )
+                masks = [image_info["mask"] for image_info in gimages_info]
+                result = matting_method(sc_images, masks)
+
+                for i, image_info in enumerate(gimages_info):
+                    image_info["result"] = result[i]
+
+        # Reconstructing the original order of image
+        result = []
+        for image in loaded_images:
+            for scene_name, images_info in list(images_per_scene.items()):
+                for image_info in images_info:
+                    if image_info["image"] == image:
+                        result.append(image_info["result"])
+                        break
+        if len(result) != len(images):
+            raise RuntimeError(
+                "Something went wrong with restoring original order. Please report this bug."
+            )
+        return result
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class AutoInterface +(scene_classifier:Β SceneClassifier, object_classifier:Β SimplifiedYoloV4, segmentation_batch_size:Β intΒ =Β 3, refining_batch_size:Β intΒ =Β 1, refining_image_size:Β intΒ =Β 900, postprocessing_batch_size:Β intΒ =Β 1, postprocessing_image_size:Β intΒ =Β 2048, segmentation_device:Β strΒ =Β 'cpu', postprocessing_device:Β strΒ =Β 'cpu', fp16=False) +
+
+

Args

+
+
scene_classifier
+
SceneClassifier instance
+
object_classifier
+
SimplifiedYoloV4 instance
+
+
+ +Expand source code + +
class AutoInterface(Interface):
+    def __init__(
+        self,
+        scene_classifier: SceneClassifier,
+        object_classifier: SimplifiedYoloV4,
+        segmentation_batch_size: int = 3,
+        refining_batch_size: int = 1,
+        refining_image_size: int = 900,
+        postprocessing_batch_size: int = 1,
+        postprocessing_image_size: int = 2048,
+        segmentation_device: str = "cpu",
+        postprocessing_device: str = "cpu",
+        fp16=False,
+    ):
+        """
+        Args:
+            scene_classifier: SceneClassifier instance
+            object_classifier: SimplifiedYoloV4 instance
+        """
+        self.scene_classifier = scene_classifier
+        self.object_classifier = object_classifier
+        self.segmentation_batch_size = segmentation_batch_size
+        self.refining_batch_size = refining_batch_size
+        self.refining_image_size = refining_image_size
+        self.postprocessing_batch_size = postprocessing_batch_size
+        self.postprocessing_image_size = postprocessing_image_size
+        self.segmentation_device = segmentation_device
+        self.postprocessing_device = postprocessing_device
+        self.fp16 = fp16
+        super().__init__(
+            seg_pipe=None, post_pipe=None, pre_pipe=None
+        )  # just for compatibility with Interface class
+
+    @staticmethod
+    def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]):
+        """
+        Selects the trimap generator parameters depending on the network
+
+        Args:
+            net: network
+        """
+        if net == TracerUniversalB7:
+            return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+        elif net == U2NET:
+            return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+        elif net == DeepLabV3:
+            return {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20}
+        elif net == BASNET:
+            return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+        else:
+            raise ValueError("Unknown network type")
+
+    def select_net(self, scene: str, images_info: List[dict]):
+        # TODO: Update this function, when new networks will be added
+        if scene == "hard":
+            for image_info in images_info:
+                objects = image_info["objects"]
+                if len(objects) == 0:
+                    image_info[
+                        "net"
+                    ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+                    continue
+                obj_counter: Dict = dict(Counter([obj for obj in objects]))
+                # fill empty classes
+                for _tag in self.object_classifier.db:
+                    if _tag not in obj_counter:
+                        obj_counter[_tag] = 0
+
+                non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0]
+
+                if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
+                    # Human only case. Hard Scene? It may be a photo of a person in far/middle distance.
+                    image_info["net"] = TracerUniversalB7
+                    # TODO: will use DeepLabV3+ for this image, it is more suitable for this case,
+                    #  but needs checks for small bbox
+                elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
+                    # Okay, we have a human without extra hairs and something else. Hard border
+                    image_info["net"] = TracerUniversalB7
+                elif obj_counter["cars"] > 0:
+                    # Cars case
+                    image_info["net"] = TracerUniversalB7
+                elif obj_counter["animals"] > 0:
+                    # Animals case
+                    image_info["net"] = U2NET  # animals should be always in soft scenes
+                else:
+                    # We have no idea what is in the image, so we will try to process it with universal model
+                    image_info["net"] = TracerUniversalB7
+
+        elif scene == "soft":
+            for image_info in images_info:
+                objects = image_info["objects"]
+                if len(objects) == 0:
+                    image_info[
+                        "net"
+                    ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+                    continue
+                obj_counter: Dict = dict(Counter([obj for obj in objects]))
+                # fill empty classes
+                for _tag in self.object_classifier.db:
+                    if _tag not in obj_counter:
+                        obj_counter[_tag] = 0
+
+                non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0]
+
+                if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
+                    # Human only case. It may be a portrait
+                    image_info["net"] = U2NET
+                elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
+                    # Okay, we have a human with hairs and something else
+                    image_info["net"] = U2NET
+                elif obj_counter["cars"] > 0:
+                    # Cars case.
+                    image_info["net"] = TracerUniversalB7
+                elif obj_counter["animals"] > 0:
+                    # Animals case
+                    image_info["net"] = U2NET  # animals should be always in soft scenes
+                else:
+                    # We have no idea what is in the image, so we will try to process it with universal model
+                    image_info["net"] = TracerUniversalB7
+        elif scene == "digital":
+            for image_info in images_info:  # TODO: not implemented yet
+                image_info[
+                    "net"
+                ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+
+    def __call__(self, images: List[Union[str, Path, Image.Image]]):
+        """
+        Automatically detects the scene and selects the appropriate network for segmentation
+
+        Args:
+            interface: Interface instance
+            images: list of images
+
+        Returns:
+            list of masks
+        """
+        loaded_images = thread_pool_processing(load_image, images)
+
+        scene_analysis = self.scene_classifier(loaded_images)
+        images_objects = self.object_classifier(loaded_images)
+
+        images_per_scene = {}
+        for i, image in enumerate(loaded_images):
+            scene_name = scene_analysis[i][0][0]
+            if scene_name not in images_per_scene:
+                images_per_scene[scene_name] = []
+            images_per_scene[scene_name].append(
+                {"image": image, "objects": images_objects[i]}
+            )
+
+        for scene_name, images_info in list(images_per_scene.items()):
+            self.select_net(scene_name, images_info)
+
+        # groups images by net
+        for scene_name, images_info in list(images_per_scene.items()):
+            groups = {}
+            for image_info in images_info:
+                net = image_info["net"]
+                if net not in groups:
+                    groups[net] = []
+                groups[net].append(image_info)
+            for net, gimages_info in list(groups.items()):
+                sc_images = [image_info["image"] for image_info in gimages_info]
+                masks = net(
+                    device=self.segmentation_device,
+                    batch_size=self.segmentation_batch_size,
+                    fp16=self.fp16,
+                )(sc_images)
+
+                for i, image_info in enumerate(gimages_info):
+                    image_info["mask"] = masks[i]
+
+        cascadepsp = CascadePSP(
+            device=self.postprocessing_device,
+            fp16=self.fp16,
+            input_tensor_size=self.refining_image_size,
+            batch_size=self.refining_batch_size,
+        )
+
+        fba = FBAMatting(
+            device=self.postprocessing_device,
+            batch_size=self.postprocessing_batch_size,
+            input_tensor_size=self.postprocessing_image_size,
+            fp16=self.fp16,
+        )
+        # groups images by net
+        for scene_name, images_info in list(images_per_scene.items()):
+            groups = {}
+            for image_info in images_info:
+                net = image_info["net"]
+                if net not in groups:
+                    groups[net] = []
+                groups[net].append(image_info)
+            for net, gimages_info in list(groups.items()):
+                sc_images = [image_info["image"] for image_info in gimages_info]
+                # noinspection PyArgumentList
+                trimap_generator = TrimapGenerator(**self.select_params_for_net(net))
+                matting_method = CasMattingMethod(
+                    refining_module=cascadepsp,
+                    matting_module=fba,
+                    trimap_generator=trimap_generator,
+                    device=self.postprocessing_device,
+                )
+                masks = [image_info["mask"] for image_info in gimages_info]
+                result = matting_method(sc_images, masks)
+
+                for i, image_info in enumerate(gimages_info):
+                    image_info["result"] = result[i]
+
+        # Reconstructing the original order of image
+        result = []
+        for image in loaded_images:
+            for scene_name, images_info in list(images_per_scene.items()):
+                for image_info in images_info:
+                    if image_info["image"] == image:
+                        result.append(image_info["result"])
+                        break
+        if len(result) != len(images):
+            raise RuntimeError(
+                "Something went wrong with restoring original order. Please report this bug."
+            )
+        return result
+
+

Ancestors

+ +

Static methods

+
+
+def select_params_for_net(net:Β Union[TracerUniversalB7,Β U2NET,Β DeepLabV3]) +
+
+

Selects the trimap generator parameters depending on the network

+

Args

+
+
net
+
network
+
+
+ +Expand source code + +
@staticmethod
+def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]):
+    """
+    Selects the trimap generator parameters depending on the network
+
+    Args:
+        net: network
+    """
+    if net == TracerUniversalB7:
+        return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+    elif net == U2NET:
+        return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+    elif net == DeepLabV3:
+        return {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20}
+    elif net == BASNET:
+        return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
+    else:
+        raise ValueError("Unknown network type")
+
+
+
+

Methods

+
+
+def select_net(self, scene:Β str, images_info:Β List[dict]) +
+
+
+
+ +Expand source code + +
def select_net(self, scene: str, images_info: List[dict]):
+    # TODO: Update this function, when new networks will be added
+    if scene == "hard":
+        for image_info in images_info:
+            objects = image_info["objects"]
+            if len(objects) == 0:
+                image_info[
+                    "net"
+                ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+                continue
+            obj_counter: Dict = dict(Counter([obj for obj in objects]))
+            # fill empty classes
+            for _tag in self.object_classifier.db:
+                if _tag not in obj_counter:
+                    obj_counter[_tag] = 0
+
+            non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0]
+
+            if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
+                # Human only case. Hard Scene? It may be a photo of a person in far/middle distance.
+                image_info["net"] = TracerUniversalB7
+                # TODO: will use DeepLabV3+ for this image, it is more suitable for this case,
+                #  but needs checks for small bbox
+            elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
+                # Okay, we have a human without extra hairs and something else. Hard border
+                image_info["net"] = TracerUniversalB7
+            elif obj_counter["cars"] > 0:
+                # Cars case
+                image_info["net"] = TracerUniversalB7
+            elif obj_counter["animals"] > 0:
+                # Animals case
+                image_info["net"] = U2NET  # animals should be always in soft scenes
+            else:
+                # We have no idea what is in the image, so we will try to process it with universal model
+                image_info["net"] = TracerUniversalB7
+
+    elif scene == "soft":
+        for image_info in images_info:
+            objects = image_info["objects"]
+            if len(objects) == 0:
+                image_info[
+                    "net"
+                ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+                continue
+            obj_counter: Dict = dict(Counter([obj for obj in objects]))
+            # fill empty classes
+            for _tag in self.object_classifier.db:
+                if _tag not in obj_counter:
+                    obj_counter[_tag] = 0
+
+            non_empty_classes = [obj for obj in obj_counter if obj_counter[obj] > 0]
+
+            if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
+                # Human only case. It may be a portrait
+                image_info["net"] = U2NET
+            elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
+                # Okay, we have a human with hairs and something else
+                image_info["net"] = U2NET
+            elif obj_counter["cars"] > 0:
+                # Cars case.
+                image_info["net"] = TracerUniversalB7
+            elif obj_counter["animals"] > 0:
+                # Animals case
+                image_info["net"] = U2NET  # animals should be always in soft scenes
+            else:
+                # We have no idea what is in the image, so we will try to process it with universal model
+                image_info["net"] = TracerUniversalB7
+    elif scene == "digital":
+        for image_info in images_info:  # TODO: not implemented yet
+            image_info[
+                "net"
+            ] = TracerUniversalB7  # It seems that the image is empty, but we will try to process it
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/api/high.html b/docs/api/carvekit/api/high.html new file mode 100644 index 0000000..a14e534 --- /dev/null +++ b/docs/api/carvekit/api/high.html @@ -0,0 +1,382 @@ + + + + + + +carvekit.api.high API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.api.high

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import warnings
+
+from carvekit.api.interface import Interface
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.ml.wrap.cascadepsp import CascadePSP
+from carvekit.ml.wrap.scene_classifier import SceneClassifier
+from carvekit.pipelines.preprocessing import AutoScene
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.pipelines.postprocessing import CasMattingMethod
+from carvekit.trimap.generator import TrimapGenerator
+
+
+class HiInterface(Interface):
+    def __init__(
+        self,
+        object_type: str = "auto",
+        batch_size_pre=5,
+        batch_size_seg=2,
+        batch_size_matting=1,
+        batch_size_refine=1,
+        device="cpu",
+        seg_mask_size=640,
+        matting_mask_size=2048,
+        refine_mask_size=900,
+        trimap_prob_threshold=231,
+        trimap_dilation=30,
+        trimap_erosion_iters=5,
+        fp16=False,
+    ):
+        """
+        Initializes High Level interface.
+
+        Args:
+            object_type (str, default=object): Interest object type. Can be "object" or "hairs-like".
+            matting_mask_size (int, default=2048):  The size of the input image for the matting neural network.
+            seg_mask_size (int, default=640): The size of the input image for the segmentation neural network.
+            batch_size_pre (int, default=5: Number of images processed per one preprocessing method call.
+            batch_size_seg (int, default=2): Number of images processed per one segmentation neural network call.
+            batch_size_matting (int, matting=1): Number of images processed per one matting neural network call.
+            device (Literal[cpu, cuda], default=cpu): Processing device
+            fp16 (bool, default=False): Use half precision. Reduce memory usage and increase speed.
+            .. CAUTION:: ⚠️ **Experimental support**
+            trimap_prob_threshold (int, default=231): Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+            trimap_dilation (int, default=30): The size of the offset radius from the object mask in pixels when forming an unknown area
+            trimap_erosion_iters (int, default=5): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+            refine_mask_size (int, default=900): The size of the input image for the refinement neural network.
+            batch_size_refine (int, default=1): Number of images processed per one refinement neural network call.
+
+
+        .. NOTE::
+            1. Changing seg_mask_size may cause an `out-of-memory` error if the value is too large, and it may also
+            result in reduced precision. I do not recommend changing this value. You can change `matting_mask_size` in
+            range from `(1024 to 4096)` to improve object edge refining quality, but it will cause extra large RAM and
+            video memory consume. Also, you can change batch size to accelerate background removal, but it also causes
+            extra large video memory consume, if value is too big.
+            2. Changing `trimap_prob_threshold`, `trimap_kernel_size`, `trimap_erosion_iters` may improve object edge
+            refining quality.
+        """
+        preprocess_pipeline = None
+
+        if object_type == "object":
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "hairs-like":
+            self._segnet = U2NET(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "auto":
+            # Using Tracer by default,
+            # but it will dynamically switch to other if needed
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+            self._scene_classifier = SceneClassifier(
+                device=device, fp16=fp16, batch_size=batch_size_pre
+            )
+            preprocess_pipeline = AutoScene(scene_classifier=self._scene_classifier)
+
+        else:
+            warnings.warn(
+                f"Unknown object type: {object_type}. Using default object type: object"
+            )
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+
+        self._cascade_psp = CascadePSP(
+            device=device,
+            batch_size=batch_size_refine,
+            input_tensor_size=refine_mask_size,
+            fp16=fp16,
+        )
+        self._fba = FBAMatting(
+            batch_size=batch_size_matting,
+            device=device,
+            input_tensor_size=matting_mask_size,
+            fp16=fp16,
+        )
+        self._trimap_generator = TrimapGenerator(
+            prob_threshold=trimap_prob_threshold,
+            kernel_size=trimap_dilation,
+            erosion_iters=trimap_erosion_iters,
+        )
+        super(HiInterface, self).__init__(
+            pre_pipe=preprocess_pipeline,
+            seg_pipe=self._segnet,
+            post_pipe=CasMattingMethod(
+                refining_module=self._cascade_psp,
+                matting_module=self._fba,
+                trimap_generator=self._trimap_generator,
+                device=device,
+            ),
+            device=device,
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class HiInterface +(object_type:Β strΒ =Β 'auto', batch_size_pre=5, batch_size_seg=2, batch_size_matting=1, batch_size_refine=1, device='cpu', seg_mask_size=640, matting_mask_size=2048, refine_mask_size=900, trimap_prob_threshold=231, trimap_dilation=30, trimap_erosion_iters=5, fp16=False) +
+
+

Initializes High Level interface.

+

Args

+
+
object_type : str, default=auto
+
Interest object type. Can be "object", "hairs-like" or "auto".
+
matting_mask_size : int, default=2048
+
The size of the input image for the matting neural network.
+
seg_mask_size : int, default=640
+
The size of the input image for the segmentation neural network.
+
batch_size_pre (int, default=5): Number of images processed per one preprocessing method call.
+
batch_size_seg : int, default=2
+
Number of images processed per one segmentation neural network call.
+
batch_size_matting : int, default=1
+
Number of images processed per one matting neural network call.
+
device : Literal[cpu, cuda], default=cpu
+
Processing device
+
fp16 : bool, default=False
+
Use half precision. Reduce memory usage and increase speed.
+
+
+

Caution: ⚠️ Experimental support

+
+
+
trimap_prob_threshold : int, default=231
+
Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+
trimap_dilation : int, default=30
+
The size of the offset radius from the object mask in pixels when forming an unknown area
+
trimap_erosion_iters : int, default=5
+
The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+
refine_mask_size : int, default=900
+
The size of the input image for the refinement neural network.
+
batch_size_refine : int, default=1
+
Number of images processed per one refinement neural network call.
+
+
+

Note

+
    +
  1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also +result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in +range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and +video memory consume. Also, you can change batch size to accelerate background removal, but it also causes +extra large video memory consume, if value is too big.
  2. +
  3. Changing trimap_prob_threshold, trimap_dilation, trimap_erosion_iters may improve object edge +refining quality.
  4. +
+
+
+ +Expand source code + +
class HiInterface(Interface):
+    def __init__(
+        self,
+        object_type: str = "auto",
+        batch_size_pre=5,
+        batch_size_seg=2,
+        batch_size_matting=1,
+        batch_size_refine=1,
+        device="cpu",
+        seg_mask_size=640,
+        matting_mask_size=2048,
+        refine_mask_size=900,
+        trimap_prob_threshold=231,
+        trimap_dilation=30,
+        trimap_erosion_iters=5,
+        fp16=False,
+    ):
+        """
+        Initializes High Level interface.
+
+        Args:
+            object_type (str, default=object): Interest object type. Can be "object" or "hairs-like".
+            matting_mask_size (int, default=2048):  The size of the input image for the matting neural network.
+            seg_mask_size (int, default=640): The size of the input image for the segmentation neural network.
+            batch_size_pre (int, default=5: Number of images processed per one preprocessing method call.
+            batch_size_seg (int, default=2): Number of images processed per one segmentation neural network call.
+            batch_size_matting (int, matting=1): Number of images processed per one matting neural network call.
+            device (Literal[cpu, cuda], default=cpu): Processing device
+            fp16 (bool, default=False): Use half precision. Reduce memory usage and increase speed.
+            .. CAUTION:: ⚠️ **Experimental support**
+            trimap_prob_threshold (int, default=231): Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+            trimap_dilation (int, default=30): The size of the offset radius from the object mask in pixels when forming an unknown area
+            trimap_erosion_iters (int, default=5): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+            refine_mask_size (int, default=900): The size of the input image for the refinement neural network.
+            batch_size_refine (int, default=1): Number of images processed per one refinement neural network call.
+
+
+        .. NOTE::
+            1. Changing seg_mask_size may cause an `out-of-memory` error if the value is too large, and it may also
+            result in reduced precision. I do not recommend changing this value. You can change `matting_mask_size` in
+            range from `(1024 to 4096)` to improve object edge refining quality, but it will cause extra large RAM and
+            video memory consume. Also, you can change batch size to accelerate background removal, but it also causes
+            extra large video memory consume, if value is too big.
+            2. Changing `trimap_prob_threshold`, `trimap_kernel_size`, `trimap_erosion_iters` may improve object edge
+            refining quality.
+        """
+        preprocess_pipeline = None
+
+        if object_type == "object":
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "hairs-like":
+            self._segnet = U2NET(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "auto":
+            # Using Tracer by default,
+            # but it will dynamically switch to other if needed
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+            self._scene_classifier = SceneClassifier(
+                device=device, fp16=fp16, batch_size=batch_size_pre
+            )
+            preprocess_pipeline = AutoScene(scene_classifier=self._scene_classifier)
+
+        else:
+            warnings.warn(
+                f"Unknown object type: {object_type}. Using default object type: object"
+            )
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+
+        self._cascade_psp = CascadePSP(
+            device=device,
+            batch_size=batch_size_refine,
+            input_tensor_size=refine_mask_size,
+            fp16=fp16,
+        )
+        self._fba = FBAMatting(
+            batch_size=batch_size_matting,
+            device=device,
+            input_tensor_size=matting_mask_size,
+            fp16=fp16,
+        )
+        self._trimap_generator = TrimapGenerator(
+            prob_threshold=trimap_prob_threshold,
+            kernel_size=trimap_dilation,
+            erosion_iters=trimap_erosion_iters,
+        )
+        super(HiInterface, self).__init__(
+            pre_pipe=preprocess_pipeline,
+            seg_pipe=self._segnet,
+            post_pipe=CasMattingMethod(
+                refining_module=self._cascade_psp,
+                matting_module=self._fba,
+                trimap_generator=self._trimap_generator,
+                device=device,
+            ),
+            device=device,
+        )
+
+

Ancestors

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/api/index.html b/docs/api/carvekit/api/index.html new file mode 100644 index 0000000..906216a --- /dev/null +++ b/docs/api/carvekit/api/index.html @@ -0,0 +1,79 @@ + + + + + + +carvekit.api API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.api

+
+
+
+
+

Sub-modules

+
+
carvekit.api.autointerface
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.api.high
+
+ +
+
carvekit.api.interface
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/api/interface.html b/docs/api/carvekit/api/interface.html new file mode 100644 index 0000000..3929695 --- /dev/null +++ b/docs/api/carvekit/api/interface.html @@ -0,0 +1,245 @@ + + + + + + +carvekit.api.interface API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.api.interface

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+from typing import Union, List, Optional
+
+from PIL import Image
+
+from carvekit.ml.wrap.basnet import BASNET
+from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene
+from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod
+from carvekit.utils.image_utils import load_image
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+
+
+class Interface:
+    def __init__(
+        self,
+        seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]],
+        pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None,
+        post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None,
+        device="cpu",
+    ):
+        """
+        Initializes an object for interacting with pipelines and other components of the CarveKit framework.
+
+        Args:
+            pre_pipe (Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]): Initialized pre-processing pipeline object
+            seg_pipe (Optional[Union[PreprocessingStub]]): Initialized segmentation network object
+            post_pipe (Optional[Union[MattingMethod]]): Initialized postprocessing pipeline object
+            device (Literal[cpu, cuda], default=cpu): The processing device that will be used to apply the masks to the images.
+        """
+        self.device = device
+        self.preprocessing_pipeline = pre_pipe
+        self.segmentation_pipeline = seg_pipe
+        self.postprocessing_pipeline = post_pipe
+
+    def __call__(
+        self, images: List[Union[str, Path, Image.Image]]
+    ) -> List[Image.Image]:
+        """
+        Removes the background from the specified images.
+
+        Args:
+            images: list of input images
+
+        Returns:
+            List of images without background as PIL.Image.Image instances
+        """
+        if self.segmentation_pipeline is None:
+            raise ValueError(
+                "Segmentation pipeline is not initialized."
+                "Override the class or pass the pipeline to the constructor."
+            )
+        images = thread_pool_processing(load_image, images)
+        if self.preprocessing_pipeline is not None:
+            masks: List[Image.Image] = self.preprocessing_pipeline(
+                interface=self, images=images
+            )
+        else:
+            masks: List[Image.Image] = self.segmentation_pipeline(images=images)
+
+        if self.postprocessing_pipeline is not None:
+            images: List[Image.Image] = self.postprocessing_pipeline(
+                images=images, masks=masks
+            )
+        else:
+            images = list(
+                map(
+                    lambda x: apply_mask(
+                        image=images[x], mask=masks[x], device=self.device
+                    ),
+                    range(len(images)),
+                )
+            )
+        return images
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class Interface +(seg_pipe:Β Union[U2NET,Β BASNET,Β DeepLabV3,Β TracerUniversalB7,Β ForwardRef(None)], pre_pipe:Β Union[PreprocessingStub,Β AutoScene,Β ForwardRef(None)]Β =Β None, post_pipe:Β Union[MattingMethod,Β CasMattingMethod,Β ForwardRef(None)]Β =Β None, device='cpu') +
+
+

Initializes an object for interacting with pipelines and other components of the CarveKit framework.

+

Args

+
+
pre_pipe : Optional[Union[PreprocessingStub, AutoScene]]
+
Initialized pre-processing pipeline object
+
seg_pipe : Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]]
+
Initialized segmentation network object
+
post_pipe : Optional[Union[MattingMethod, CasMattingMethod]]
+
Initialized postprocessing pipeline object
+
device : Literal[cpu, cuda], default=cpu
+
The processing device that will be used to apply the masks to the images.
+
+
+ +Expand source code + +
class Interface:
+    def __init__(
+        self,
+        seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]],
+        pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None,
+        post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None,
+        device="cpu",
+    ):
+        """
+        Initializes an object for interacting with pipelines and other components of the CarveKit framework.
+
+        Args:
+            pre_pipe (Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]): Initialized pre-processing pipeline object
+            seg_pipe (Optional[Union[PreprocessingStub]]): Initialized segmentation network object
+            post_pipe (Optional[Union[MattingMethod]]): Initialized postprocessing pipeline object
+            device (Literal[cpu, cuda], default=cpu): The processing device that will be used to apply the masks to the images.
+        """
+        self.device = device
+        self.preprocessing_pipeline = pre_pipe
+        self.segmentation_pipeline = seg_pipe
+        self.postprocessing_pipeline = post_pipe
+
+    def __call__(
+        self, images: List[Union[str, Path, Image.Image]]
+    ) -> List[Image.Image]:
+        """
+        Removes the background from the specified images.
+
+        Args:
+            images: list of input images
+
+        Returns:
+            List of images without background as PIL.Image.Image instances
+        """
+        if self.segmentation_pipeline is None:
+            raise ValueError(
+                "Segmentation pipeline is not initialized."
+                "Override the class or pass the pipeline to the constructor."
+            )
+        images = thread_pool_processing(load_image, images)
+        if self.preprocessing_pipeline is not None:
+            masks: List[Image.Image] = self.preprocessing_pipeline(
+                interface=self, images=images
+            )
+        else:
+            masks: List[Image.Image] = self.segmentation_pipeline(images=images)
+
+        if self.postprocessing_pipeline is not None:
+            images: List[Image.Image] = self.postprocessing_pipeline(
+                images=images, masks=masks
+            )
+        else:
+            images = list(
+                map(
+                    lambda x: apply_mask(
+                        image=images[x], mask=masks[x], device=self.device
+                    ),
+                    range(len(images)),
+                )
+            )
+        return images
+
+

Subclasses

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/index.html b/docs/api/carvekit/index.html new file mode 100644 index 0000000..e80162d --- /dev/null +++ b/docs/api/carvekit/index.html @@ -0,0 +1,91 @@ + + + + + + +carvekit API documentation + + + + + + + + + + + +
+
+
+

Package carvekit

+
+
+
+ +Expand source code + +
version = "4.5.0"
+
+
+
+

Sub-modules

+
+
carvekit.api
+
+
+
+
carvekit.ml
+
+
+
+
carvekit.pipelines
+
+
+
+
carvekit.trimap
+
+
+
+
carvekit.utils
+
+
+
+
carvekit.web
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/basnet/basnet.html b/docs/api/carvekit/ml/arch/basnet/basnet.html new file mode 100644 index 0000000..50e45d1 --- /dev/null +++ b/docs/api/carvekit/ml/arch/basnet/basnet.html @@ -0,0 +1,1600 @@ + + + + + + +carvekit.ml.arch.basnet.basnet API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.basnet.basnet

+
+
+

Source url: https://github.com/NathanUA/BASNet +Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: MIT License

+
+ +Expand source code + +
"""
+Source url: https://github.com/NathanUA/BASNet
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: MIT License
+"""
+import torch
+import torch.nn as nn
+from torchvision import models
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockDe(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlockDe, self).__init__()
+
+        self.convRes = conv3x3(inplanes, planes, stride)
+        self.bnRes = nn.BatchNorm2d(planes)
+        self.reluRes = nn.ReLU(inplace=True)
+
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = self.convRes(x)
+        residual = self.bnRes(residual)
+        residual = self.reluRes(residual)
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class RefUnet(nn.Module):
+    def __init__(self, in_ch, inc_ch):
+        super(RefUnet, self).__init__()
+
+        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)
+
+        self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(64)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn3 = nn.BatchNorm2d(64)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn4 = nn.BatchNorm2d(64)
+        self.relu4 = nn.ReLU(inplace=True)
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn5 = nn.BatchNorm2d(64)
+        self.relu5 = nn.ReLU(inplace=True)
+
+        self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d4 = nn.BatchNorm2d(64)
+        self.relu_d4 = nn.ReLU(inplace=True)
+
+        self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d3 = nn.BatchNorm2d(64)
+        self.relu_d3 = nn.ReLU(inplace=True)
+
+        self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d2 = nn.BatchNorm2d(64)
+        self.relu_d2 = nn.ReLU(inplace=True)
+
+        self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d1 = nn.BatchNorm2d(64)
+        self.relu_d1 = nn.ReLU(inplace=True)
+
+        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)
+
+        self.upscore2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+    def forward(self, x):
+        hx = x
+        hx = self.conv0(hx)
+
+        hx1 = self.relu1(self.bn1(self.conv1(hx)))
+        hx = self.pool1(hx1)
+
+        hx2 = self.relu2(self.bn2(self.conv2(hx)))
+        hx = self.pool2(hx2)
+
+        hx3 = self.relu3(self.bn3(self.conv3(hx)))
+        hx = self.pool3(hx3)
+
+        hx4 = self.relu4(self.bn4(self.conv4(hx)))
+        hx = self.pool4(hx4)
+
+        hx5 = self.relu5(self.bn5(self.conv5(hx)))
+
+        hx = self.upscore2(hx5)
+
+        d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1))))
+        hx = self.upscore2(d4)
+
+        d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1))))
+        hx = self.upscore2(d3)
+
+        d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1))))
+        hx = self.upscore2(d2)
+
+        d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1))))
+
+        residual = self.conv_d0(d1)
+
+        return x + residual
+
+
+class BASNet(nn.Module):
+    def __init__(self, n_channels, n_classes):
+        super(BASNet, self).__init__()
+
+        resnet = models.resnet34(pretrained=False)
+
+        # -------------Encoder--------------
+
+        self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1)
+        self.inbn = nn.BatchNorm2d(64)
+        self.inrelu = nn.ReLU(inplace=True)
+
+        # stage 1
+        self.encoder1 = resnet.layer1  # 224
+        # stage 2
+        self.encoder2 = resnet.layer2  # 112
+        # stage 3
+        self.encoder3 = resnet.layer3  # 56
+        # stage 4
+        self.encoder4 = resnet.layer4  # 28
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        # stage 5
+        self.resb5_1 = BasicBlock(512, 512)
+        self.resb5_2 = BasicBlock(512, 512)
+        self.resb5_3 = BasicBlock(512, 512)  # 14
+
+        self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        # stage 6
+        self.resb6_1 = BasicBlock(512, 512)
+        self.resb6_2 = BasicBlock(512, 512)
+        self.resb6_3 = BasicBlock(512, 512)  # 7
+
+        # -------------Bridge--------------
+
+        # stage Bridge
+        self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)  # 7
+        self.bnbg_1 = nn.BatchNorm2d(512)
+        self.relubg_1 = nn.ReLU(inplace=True)
+        self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bnbg_m = nn.BatchNorm2d(512)
+        self.relubg_m = nn.ReLU(inplace=True)
+        self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bnbg_2 = nn.BatchNorm2d(512)
+        self.relubg_2 = nn.ReLU(inplace=True)
+
+        # -------------Decoder--------------
+
+        # stage 6d
+        self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
+        self.bn6d_1 = nn.BatchNorm2d(512)
+        self.relu6d_1 = nn.ReLU(inplace=True)
+
+        self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bn6d_m = nn.BatchNorm2d(512)
+        self.relu6d_m = nn.ReLU(inplace=True)
+
+        self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bn6d_2 = nn.BatchNorm2d(512)
+        self.relu6d_2 = nn.ReLU(inplace=True)
+
+        # stage 5d
+        self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
+        self.bn5d_1 = nn.BatchNorm2d(512)
+        self.relu5d_1 = nn.ReLU(inplace=True)
+
+        self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn5d_m = nn.BatchNorm2d(512)
+        self.relu5d_m = nn.ReLU(inplace=True)
+
+        self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn5d_2 = nn.BatchNorm2d(512)
+        self.relu5d_2 = nn.ReLU(inplace=True)
+
+        # stage 4d
+        self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 32
+        self.bn4d_1 = nn.BatchNorm2d(512)
+        self.relu4d_1 = nn.ReLU(inplace=True)
+
+        self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn4d_m = nn.BatchNorm2d(512)
+        self.relu4d_m = nn.ReLU(inplace=True)
+
+        self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1)
+        self.bn4d_2 = nn.BatchNorm2d(256)
+        self.relu4d_2 = nn.ReLU(inplace=True)
+
+        # stage 3d
+        self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1)  # 64
+        self.bn3d_1 = nn.BatchNorm2d(256)
+        self.relu3d_1 = nn.ReLU(inplace=True)
+
+        self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1)
+        self.bn3d_m = nn.BatchNorm2d(256)
+        self.relu3d_m = nn.ReLU(inplace=True)
+
+        self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1)
+        self.bn3d_2 = nn.BatchNorm2d(128)
+        self.relu3d_2 = nn.ReLU(inplace=True)
+
+        # stage 2d
+
+        self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1)  # 128
+        self.bn2d_1 = nn.BatchNorm2d(128)
+        self.relu2d_1 = nn.ReLU(inplace=True)
+
+        self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1)
+        self.bn2d_m = nn.BatchNorm2d(128)
+        self.relu2d_m = nn.ReLU(inplace=True)
+
+        self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn2d_2 = nn.BatchNorm2d(64)
+        self.relu2d_2 = nn.ReLU(inplace=True)
+
+        # stage 1d
+        self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1)  # 256
+        self.bn1d_1 = nn.BatchNorm2d(64)
+        self.relu1d_1 = nn.ReLU(inplace=True)
+
+        self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn1d_m = nn.BatchNorm2d(64)
+        self.relu1d_m = nn.ReLU(inplace=True)
+
+        self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn1d_2 = nn.BatchNorm2d(64)
+        self.relu1d_2 = nn.ReLU(inplace=True)
+
+        # -------------Bilinear Upsampling--------------
+        self.upscore6 = nn.Upsample(
+            scale_factor=32, mode="bilinear", align_corners=False
+        )
+        self.upscore5 = nn.Upsample(
+            scale_factor=16, mode="bilinear", align_corners=False
+        )
+        self.upscore4 = nn.Upsample(
+            scale_factor=8, mode="bilinear", align_corners=False
+        )
+        self.upscore3 = nn.Upsample(
+            scale_factor=4, mode="bilinear", align_corners=False
+        )
+        self.upscore2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        # -------------Side Output--------------
+        self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
+        self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
+        self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
+        self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)
+
+        # -------------Refine Module-------------
+        self.refunet = RefUnet(1, 64)
+
+    def forward(self, x):
+        hx = x
+
+        # -------------Encoder-------------
+        hx = self.inconv(hx)
+        hx = self.inbn(hx)
+        hx = self.inrelu(hx)
+
+        h1 = self.encoder1(hx)  # 256
+        h2 = self.encoder2(h1)  # 128
+        h3 = self.encoder3(h2)  # 64
+        h4 = self.encoder4(h3)  # 32
+
+        hx = self.pool4(h4)  # 16
+
+        hx = self.resb5_1(hx)
+        hx = self.resb5_2(hx)
+        h5 = self.resb5_3(hx)
+
+        hx = self.pool5(h5)  # 8
+
+        hx = self.resb6_1(hx)
+        hx = self.resb6_2(hx)
+        h6 = self.resb6_3(hx)
+
+        # -------------Bridge-------------
+        hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
+        hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
+        hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
+
+        # -------------Decoder-------------
+
+        hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
+        hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
+        hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
+
+        hx = self.upscore2(hd6)  # 8 -> 16
+
+        hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
+        hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
+        hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
+
+        hx = self.upscore2(hd5)  # 16 -> 32
+
+        hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
+        hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
+        hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
+
+        hx = self.upscore2(hd4)  # 32 -> 64
+
+        hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
+        hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
+        hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
+
+        hx = self.upscore2(hd3)  # 64 -> 128
+
+        hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
+        hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
+        hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
+
+        hx = self.upscore2(hd2)  # 128 -> 256
+
+        hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
+        hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
+        hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
+
+        # -------------Side Output-------------
+        db = self.outconvb(hbg)
+        db = self.upscore6(db)  # 8->256
+
+        d6 = self.outconv6(hd6)
+        d6 = self.upscore6(d6)  # 8->256
+
+        d5 = self.outconv5(hd5)
+        d5 = self.upscore5(d5)  # 16->256
+
+        d4 = self.outconv4(hd4)
+        d4 = self.upscore4(d4)  # 32->256
+
+        d3 = self.outconv3(hd3)
+        d3 = self.upscore3(d3)  # 64->256
+
+        d2 = self.outconv2(hd2)
+        d2 = self.upscore2(d2)  # 128->256
+
+        d1 = self.outconv1(hd1)  # 256
+
+        # -------------Refine Module-------------
+        dout = self.refunet(d1)  # 256
+
+        return (
+            torch.sigmoid(dout),
+            torch.sigmoid(d1),
+            torch.sigmoid(d2),
+            torch.sigmoid(d3),
+            torch.sigmoid(d4),
+            torch.sigmoid(d5),
+            torch.sigmoid(d6),
+            torch.sigmoid(db),
+        )
+
+
+
+
+
+
+
+

Functions

+
+
+def conv3x3(in_planes, out_planes, stride=1) +
+
+

3x3 convolution with padding

+
+ +Expand source code + +
def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+
+
+
+

Classes

+
+
+class BASNet +(n_channels, n_classes) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BASNet(nn.Module):
+    def __init__(self, n_channels, n_classes):
+        super(BASNet, self).__init__()
+
+        resnet = models.resnet34(pretrained=False)
+
+        # -------------Encoder--------------
+
+        self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1)
+        self.inbn = nn.BatchNorm2d(64)
+        self.inrelu = nn.ReLU(inplace=True)
+
+        # stage 1
+        self.encoder1 = resnet.layer1  # 224
+        # stage 2
+        self.encoder2 = resnet.layer2  # 112
+        # stage 3
+        self.encoder3 = resnet.layer3  # 56
+        # stage 4
+        self.encoder4 = resnet.layer4  # 28
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        # stage 5
+        self.resb5_1 = BasicBlock(512, 512)
+        self.resb5_2 = BasicBlock(512, 512)
+        self.resb5_3 = BasicBlock(512, 512)  # 14
+
+        self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        # stage 6
+        self.resb6_1 = BasicBlock(512, 512)
+        self.resb6_2 = BasicBlock(512, 512)
+        self.resb6_3 = BasicBlock(512, 512)  # 7
+
+        # -------------Bridge--------------
+
+        # stage Bridge
+        self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)  # 7
+        self.bnbg_1 = nn.BatchNorm2d(512)
+        self.relubg_1 = nn.ReLU(inplace=True)
+        self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bnbg_m = nn.BatchNorm2d(512)
+        self.relubg_m = nn.ReLU(inplace=True)
+        self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bnbg_2 = nn.BatchNorm2d(512)
+        self.relubg_2 = nn.ReLU(inplace=True)
+
+        # -------------Decoder--------------
+
+        # stage 6d
+        self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
+        self.bn6d_1 = nn.BatchNorm2d(512)
+        self.relu6d_1 = nn.ReLU(inplace=True)
+
+        self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bn6d_m = nn.BatchNorm2d(512)
+        self.relu6d_m = nn.ReLU(inplace=True)
+
+        self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bn6d_2 = nn.BatchNorm2d(512)
+        self.relu6d_2 = nn.ReLU(inplace=True)
+
+        # stage 5d
+        self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
+        self.bn5d_1 = nn.BatchNorm2d(512)
+        self.relu5d_1 = nn.ReLU(inplace=True)
+
+        self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn5d_m = nn.BatchNorm2d(512)
+        self.relu5d_m = nn.ReLU(inplace=True)
+
+        self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn5d_2 = nn.BatchNorm2d(512)
+        self.relu5d_2 = nn.ReLU(inplace=True)
+
+        # stage 4d
+        self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 32
+        self.bn4d_1 = nn.BatchNorm2d(512)
+        self.relu4d_1 = nn.ReLU(inplace=True)
+
+        self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn4d_m = nn.BatchNorm2d(512)
+        self.relu4d_m = nn.ReLU(inplace=True)
+
+        self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1)
+        self.bn4d_2 = nn.BatchNorm2d(256)
+        self.relu4d_2 = nn.ReLU(inplace=True)
+
+        # stage 3d
+        self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1)  # 64
+        self.bn3d_1 = nn.BatchNorm2d(256)
+        self.relu3d_1 = nn.ReLU(inplace=True)
+
+        self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1)
+        self.bn3d_m = nn.BatchNorm2d(256)
+        self.relu3d_m = nn.ReLU(inplace=True)
+
+        self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1)
+        self.bn3d_2 = nn.BatchNorm2d(128)
+        self.relu3d_2 = nn.ReLU(inplace=True)
+
+        # stage 2d
+
+        self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1)  # 128
+        self.bn2d_1 = nn.BatchNorm2d(128)
+        self.relu2d_1 = nn.ReLU(inplace=True)
+
+        self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1)
+        self.bn2d_m = nn.BatchNorm2d(128)
+        self.relu2d_m = nn.ReLU(inplace=True)
+
+        self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn2d_2 = nn.BatchNorm2d(64)
+        self.relu2d_2 = nn.ReLU(inplace=True)
+
+        # stage 1d
+        self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1)  # 256
+        self.bn1d_1 = nn.BatchNorm2d(64)
+        self.relu1d_1 = nn.ReLU(inplace=True)
+
+        self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn1d_m = nn.BatchNorm2d(64)
+        self.relu1d_m = nn.ReLU(inplace=True)
+
+        self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn1d_2 = nn.BatchNorm2d(64)
+        self.relu1d_2 = nn.ReLU(inplace=True)
+
+        # -------------Bilinear Upsampling--------------
+        self.upscore6 = nn.Upsample(
+            scale_factor=32, mode="bilinear", align_corners=False
+        )
+        self.upscore5 = nn.Upsample(
+            scale_factor=16, mode="bilinear", align_corners=False
+        )
+        self.upscore4 = nn.Upsample(
+            scale_factor=8, mode="bilinear", align_corners=False
+        )
+        self.upscore3 = nn.Upsample(
+            scale_factor=4, mode="bilinear", align_corners=False
+        )
+        self.upscore2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        # -------------Side Output--------------
+        self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
+        self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
+        self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
+        self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)
+
+        # -------------Refine Module-------------
+        self.refunet = RefUnet(1, 64)
+
+    def forward(self, x):
+        hx = x
+
+        # -------------Encoder-------------
+        hx = self.inconv(hx)
+        hx = self.inbn(hx)
+        hx = self.inrelu(hx)
+
+        h1 = self.encoder1(hx)  # 256
+        h2 = self.encoder2(h1)  # 128
+        h3 = self.encoder3(h2)  # 64
+        h4 = self.encoder4(h3)  # 32
+
+        hx = self.pool4(h4)  # 16
+
+        hx = self.resb5_1(hx)
+        hx = self.resb5_2(hx)
+        h5 = self.resb5_3(hx)
+
+        hx = self.pool5(h5)  # 8
+
+        hx = self.resb6_1(hx)
+        hx = self.resb6_2(hx)
+        h6 = self.resb6_3(hx)
+
+        # -------------Bridge-------------
+        hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
+        hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
+        hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
+
+        # -------------Decoder-------------
+
+        hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
+        hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
+        hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
+
+        hx = self.upscore2(hd6)  # 8 -> 16
+
+        hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
+        hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
+        hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
+
+        hx = self.upscore2(hd5)  # 16 -> 32
+
+        hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
+        hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
+        hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
+
+        hx = self.upscore2(hd4)  # 32 -> 64
+
+        hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
+        hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
+        hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
+
+        hx = self.upscore2(hd3)  # 64 -> 128
+
+        hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
+        hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
+        hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
+
+        hx = self.upscore2(hd2)  # 128 -> 256
+
+        hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
+        hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
+        hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
+
+        # -------------Side Output-------------
+        db = self.outconvb(hbg)
+        db = self.upscore6(db)  # 8->256
+
+        d6 = self.outconv6(hd6)
+        d6 = self.upscore6(d6)  # 8->256
+
+        d5 = self.outconv5(hd5)
+        d5 = self.upscore5(d5)  # 16->256
+
+        d4 = self.outconv4(hd4)
+        d4 = self.upscore4(d4)  # 32->256
+
+        d3 = self.outconv3(hd3)
+        d3 = self.upscore3(d3)  # 64->256
+
+        d2 = self.outconv2(hd2)
+        d2 = self.upscore2(d2)  # 128->256
+
+        d1 = self.outconv1(hd1)  # 256
+
+        # -------------Refine Module-------------
+        dout = self.refunet(d1)  # 256
+
+        return (
+            torch.sigmoid(dout),
+            torch.sigmoid(d1),
+            torch.sigmoid(d2),
+            torch.sigmoid(d3),
+            torch.sigmoid(d4),
+            torch.sigmoid(d5),
+            torch.sigmoid(d6),
+            torch.sigmoid(db),
+        )
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    hx = x
+
+    # -------------Encoder-------------
+    hx = self.inconv(hx)
+    hx = self.inbn(hx)
+    hx = self.inrelu(hx)
+
+    h1 = self.encoder1(hx)  # 256
+    h2 = self.encoder2(h1)  # 128
+    h3 = self.encoder3(h2)  # 64
+    h4 = self.encoder4(h3)  # 32
+
+    hx = self.pool4(h4)  # 16
+
+    hx = self.resb5_1(hx)
+    hx = self.resb5_2(hx)
+    h5 = self.resb5_3(hx)
+
+    hx = self.pool5(h5)  # 8
+
+    hx = self.resb6_1(hx)
+    hx = self.resb6_2(hx)
+    h6 = self.resb6_3(hx)
+
+    # -------------Bridge-------------
+    hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
+    hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
+    hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
+
+    # -------------Decoder-------------
+
+    hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
+    hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
+    hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
+
+    hx = self.upscore2(hd6)  # 8 -> 16
+
+    hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
+    hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
+    hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
+
+    hx = self.upscore2(hd5)  # 16 -> 32
+
+    hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
+    hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
+    hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
+
+    hx = self.upscore2(hd4)  # 32 -> 64
+
+    hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
+    hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
+    hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
+
+    hx = self.upscore2(hd3)  # 64 -> 128
+
+    hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
+    hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
+    hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
+
+    hx = self.upscore2(hd2)  # 128 -> 256
+
+    hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
+    hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
+    hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
+
+    # -------------Side Output-------------
+    db = self.outconvb(hbg)
+    db = self.upscore6(db)  # 8->256
+
+    d6 = self.outconv6(hd6)
+    d6 = self.upscore6(d6)  # 8->256
+
+    d5 = self.outconv5(hd5)
+    d5 = self.upscore5(d5)  # 16->256
+
+    d4 = self.outconv4(hd4)
+    d4 = self.upscore4(d4)  # 32->256
+
+    d3 = self.outconv3(hd3)
+    d3 = self.upscore3(d3)  # 64->256
+
+    d2 = self.outconv2(hd2)
+    d2 = self.upscore2(d2)  # 128->256
+
+    d1 = self.outconv1(hd1)  # 256
+
+    # -------------Refine Module-------------
+    dout = self.refunet(d1)  # 256
+
+    return (
+        torch.sigmoid(dout),
+        torch.sigmoid(d1),
+        torch.sigmoid(d2),
+        torch.sigmoid(d3),
+        torch.sigmoid(d4),
+        torch.sigmoid(d5),
+        torch.sigmoid(d6),
+        torch.sigmoid(db),
+    )
+
+
+
+
+
+class BasicBlock +(inplanes, planes, stride=1, downsample=None) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var expansion
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    residual = x
+
+    out = self.conv1(x)
+    out = self.bn1(out)
+    out = self.relu(out)
+
+    out = self.conv2(out)
+    out = self.bn2(out)
+
+    if self.downsample is not None:
+        residual = self.downsample(x)
+
+    out += residual
+    out = self.relu(out)
+
+    return out
+
+
+
+
+
+class BasicBlockDe +(inplanes, planes, stride=1, downsample=None) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BasicBlockDe(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlockDe, self).__init__()
+
+        self.convRes = conv3x3(inplanes, planes, stride)
+        self.bnRes = nn.BatchNorm2d(planes)
+        self.reluRes = nn.ReLU(inplace=True)
+
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = self.convRes(x)
+        residual = self.bnRes(residual)
+        residual = self.reluRes(residual)
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var expansion
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    residual = self.convRes(x)
+    residual = self.bnRes(residual)
+    residual = self.reluRes(residual)
+
+    out = self.conv1(x)
+    out = self.bn1(out)
+    out = self.relu(out)
+
+    out = self.conv2(out)
+    out = self.bn2(out)
+
+    if self.downsample is not None:
+        residual = self.downsample(x)
+
+    out += residual
+    out = self.relu(out)
+
+    return out
+
+
+
+
+
+class Bottleneck +(inplanes, planes, stride=1, downsample=None) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var expansion
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance itself afterwards +instead of calling forward() directly, since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    residual = x
+
+    out = self.conv1(x)
+    out = self.bn1(out)
+    out = self.relu(out)
+
+    out = self.conv2(out)
+    out = self.bn2(out)
+    out = self.relu(out)
+
+    out = self.conv3(out)
+    out = self.bn3(out)
+
+    if self.downsample is not None:
+        residual = self.downsample(x)
+
+    out += residual
+    out = self.relu(out)
+
+    return out
+
+
+
+
+
+class RefUnet +(in_ch, inc_ch) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call Module.to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class RefUnet(nn.Module):
+    def __init__(self, in_ch, inc_ch):
+        super(RefUnet, self).__init__()
+
+        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)
+
+        self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(64)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn3 = nn.BatchNorm2d(64)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn4 = nn.BatchNorm2d(64)
+        self.relu4 = nn.ReLU(inplace=True)
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn5 = nn.BatchNorm2d(64)
+        self.relu5 = nn.ReLU(inplace=True)
+
+        self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d4 = nn.BatchNorm2d(64)
+        self.relu_d4 = nn.ReLU(inplace=True)
+
+        self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d3 = nn.BatchNorm2d(64)
+        self.relu_d3 = nn.ReLU(inplace=True)
+
+        self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d2 = nn.BatchNorm2d(64)
+        self.relu_d2 = nn.ReLU(inplace=True)
+
+        self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d1 = nn.BatchNorm2d(64)
+        self.relu_d1 = nn.ReLU(inplace=True)
+
+        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)
+
+        self.upscore2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+    def forward(self, x):
+        hx = x
+        hx = self.conv0(hx)
+
+        hx1 = self.relu1(self.bn1(self.conv1(hx)))
+        hx = self.pool1(hx1)
+
+        hx2 = self.relu2(self.bn2(self.conv2(hx)))
+        hx = self.pool2(hx2)
+
+        hx3 = self.relu3(self.bn3(self.conv3(hx)))
+        hx = self.pool3(hx3)
+
+        hx4 = self.relu4(self.bn4(self.conv4(hx)))
+        hx = self.pool4(hx4)
+
+        hx5 = self.relu5(self.bn5(self.conv5(hx)))
+
+        hx = self.upscore2(hx5)
+
+        d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1))))
+        hx = self.upscore2(d4)
+
+        d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1))))
+        hx = self.upscore2(d3)
+
+        d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1))))
+        hx = self.upscore2(d2)
+
+        d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1))))
+
+        residual = self.conv_d0(d1)
+
+        return x + residual
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance itself afterwards +instead of calling forward() directly, since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    hx = x
+    hx = self.conv0(hx)
+
+    hx1 = self.relu1(self.bn1(self.conv1(hx)))
+    hx = self.pool1(hx1)
+
+    hx2 = self.relu2(self.bn2(self.conv2(hx)))
+    hx = self.pool2(hx2)
+
+    hx3 = self.relu3(self.bn3(self.conv3(hx)))
+    hx = self.pool3(hx3)
+
+    hx4 = self.relu4(self.bn4(self.conv4(hx)))
+    hx = self.pool4(hx4)
+
+    hx5 = self.relu5(self.bn5(self.conv5(hx)))
+
+    hx = self.upscore2(hx5)
+
+    d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1))))
+    hx = self.upscore2(d4)
+
+    d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1))))
+    hx = self.upscore2(d3)
+
+    d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1))))
+    hx = self.upscore2(d2)
+
+    d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1))))
+
+    residual = self.conv_d0(d1)
+
+    return x + residual
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/basnet/index.html b/docs/api/carvekit/ml/arch/basnet/index.html new file mode 100644 index 0000000..4730a60 --- /dev/null +++ b/docs/api/carvekit/ml/arch/basnet/index.html @@ -0,0 +1,67 @@ + + + + + + +carvekit.ml.arch.basnet API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.basnet

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.basnet.basnet
+
+

Source url: https://github.com/NathanUA/BASNet +Modified by Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO). +License: MIT License

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/cascadepsp/extractors.html b/docs/api/carvekit/ml/arch/cascadepsp/extractors.html new file mode 100644 index 0000000..e9449ac --- /dev/null +++ b/docs/api/carvekit/ml/arch/cascadepsp/extractors.html @@ -0,0 +1,522 @@ + + + + + + +carvekit.ml.arch.cascadepsp.extractors API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.cascadepsp.extractors

+
+
+

Modified by Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO). +Source url: https://github.com/hkchengrex/CascadePSP +License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/hkchengrex/CascadePSP
+License: MIT License
+"""
+import math
+
+import torch.nn as nn
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        dilation=dilation,
+        bias=False,
+    )
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            dilation=dilation,
+            padding=dilation,
+            bias=False,
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers=(3, 4, 23, 3)):
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+        self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2.0 / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = [block(self.inplanes, planes, stride, downsample)]
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, dilation=dilation))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)  # /2
+        x = self.bn1(x_1)
+        x = self.relu(x)
+        x = self.maxpool(x)  # /2
+
+        x_2 = self.layer1(x)
+        x = self.layer2(x_2)  # /2
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x, x_1, x_2
+
+
+def resnet50():
+    model = ResNet(Bottleneck, [3, 4, 6, 3])
+    return model
+
+
+
+
+
+
+
+

Functions

+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1) +
+
+
+
+ +Expand source code + +
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        dilation=dilation,
+        bias=False,
+    )
+
+
+
+def resnet50() +
+
+
+
+ +Expand source code + +
def resnet50():
+    model = ResNet(Bottleneck, [3, 4, 6, 3])
+    return model
+
+
+
+
+
+

Classes

+
+
+class Bottleneck +(inplanes, planes, stride=1, downsample=None, dilation=1) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call Module.to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            dilation=dilation,
+            padding=dilation,
+            bias=False,
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var expansion
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance itself afterwards +instead of calling forward() directly, since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    residual = x
+
+    out = self.conv1(x)
+    out = self.bn1(out)
+    out = self.relu(out)
+
+    out = self.conv2(out)
+    out = self.bn2(out)
+    out = self.relu(out)
+
+    out = self.conv3(out)
+    out = self.bn3(out)
+
+    if self.downsample is not None:
+        residual = self.downsample(x)
+
+    out += residual
+    out = self.relu(out)
+
+    return out
+
+
+
+
+
+class ResNet +(block, layers=(3, 4, 23, 3)) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call Module.to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResNet(nn.Module):
+    def __init__(self, block, layers=(3, 4, 23, 3)):
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+        self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2.0 / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = [block(self.inplanes, planes, stride, downsample)]
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, dilation=dilation))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)  # /2
+        x = self.bn1(x_1)
+        x = self.relu(x)
+        x = self.maxpool(x)  # /2
+
+        x_2 = self.layer1(x)
+        x = self.layer2(x_2)  # /2
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x, x_1, x_2
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance itself afterwards +instead of calling forward() directly, since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x_1 = self.conv1(x)  # /2
+    x = self.bn1(x_1)
+    x = self.relu(x)
+    x = self.maxpool(x)  # /2
+
+    x_2 = self.layer1(x)
+    x = self.layer2(x_2)  # /2
+    x = self.layer3(x)
+    x = self.layer4(x)
+
+    return x, x_1, x_2
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/cascadepsp/index.html b/docs/api/carvekit/ml/arch/cascadepsp/index.html new file mode 100644 index 0000000..3b6ac43 --- /dev/null +++ b/docs/api/carvekit/ml/arch/cascadepsp/index.html @@ -0,0 +1,79 @@ + + + + + + +carvekit.ml.arch.cascadepsp API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.cascadepsp

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.cascadepsp.extractors
+
+

Modified by Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO). +Source url: https://github.com/hkchengrex/CascadePSP +License: MIT License

+
+
carvekit.ml.arch.cascadepsp.pspnet
+
+

Modified by Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO). +Source url: https://github.com/hkchengrex/CascadePSP +License: MIT License

+
+
carvekit.ml.arch.cascadepsp.utils
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/cascadepsp/pspnet.html b/docs/api/carvekit/ml/arch/cascadepsp/pspnet.html new file mode 100644 index 0000000..9e64ebf --- /dev/null +++ b/docs/api/carvekit/ml/arch/cascadepsp/pspnet.html @@ -0,0 +1,771 @@ + + + + + + +carvekit.ml.arch.cascadepsp.pspnet API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.cascadepsp.pspnet

+
+
+

Modified by Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO). +Source url: https://github.com/hkchengrex/CascadePSP +License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/hkchengrex/CascadePSP
+License: MIT License
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from carvekit.ml.arch.cascadepsp.extractors import resnet50
+
+
+class PSPModule(nn.Module):
+    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
+        super().__init__()
+        self.stages = []
+        self.stages = nn.ModuleList(
+            [self._make_stage(features, size) for size in sizes]
+        )
+        self.bottleneck = nn.Conv2d(
+            features * (len(sizes) + 1), out_features, kernel_size=1
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+    def _make_stage(self, features, size):
+        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
+        return nn.Sequential(prior, conv)
+
+    def forward(self, feats):
+        h, w = feats.size(2), feats.size(3)
+        set_priors = [
+            F.interpolate(
+                input=stage(feats), size=(h, w), mode="bilinear", align_corners=False
+            )
+            for stage in self.stages
+        ]
+        priors = set_priors + [feats]
+        bottle = self.bottleneck(torch.cat(priors, 1))
+        return self.relu(bottle)
+
+
+class PSPUpsample(nn.Module):
+    def __init__(self, x_channels, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.BatchNorm2d(in_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+        )
+
+        self.conv2 = nn.Sequential(
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+        )
+
+        self.shortcut = nn.Conv2d(x_channels, out_channels, kernel_size=1)
+
+    def forward(self, x, up):
+        x = F.interpolate(input=x, scale_factor=2, mode="bilinear", align_corners=False)
+
+        p = self.conv(torch.cat([x, up], 1).type(x.type()))
+        sc = self.shortcut(x)
+
+        p = p + sc
+
+        p2 = self.conv2(p)
+
+        return p + p2
+
+
+class RefinementModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.feats = resnet50()
+        self.psp = PSPModule(2048, 1024, (1, 2, 3, 6))
+
+        self.up_1 = PSPUpsample(1024, 1024 + 256, 512)
+        self.up_2 = PSPUpsample(512, 512 + 64, 256)
+        self.up_3 = PSPUpsample(256, 256 + 3, 32)
+
+        self.final_28 = nn.Sequential(
+            nn.Conv2d(1024, 32, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 1, kernel_size=1),
+        )
+
+        self.final_56 = nn.Sequential(
+            nn.Conv2d(512, 32, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 1, kernel_size=1),
+        )
+
+        self.final_11 = nn.Conv2d(32 + 3, 32, kernel_size=1)
+        self.final_21 = nn.Conv2d(32, 1, kernel_size=1)
+
+    def forward(self, x, seg, inter_s8=None, inter_s4=None):
+
+        images = {}
+
+        """
+        First iteration, s8 output
+        """
+        if inter_s8 is None:
+            p = torch.cat((x, seg, seg, seg), 1)
+
+            f, f_1, f_2 = self.feats(p)
+            p = self.psp(f)
+
+            inter_s8 = self.final_28(p)
+            r_inter_s8 = F.interpolate(
+                inter_s8, scale_factor=8, mode="bilinear", align_corners=False
+            )
+            r_inter_tanh_s8 = torch.tanh(r_inter_s8)
+
+            images["pred_28"] = torch.sigmoid(r_inter_s8)
+            images["out_28"] = r_inter_s8
+        else:
+            r_inter_tanh_s8 = inter_s8
+
+        """
+        Second iteration, s8 output
+        """
+        if inter_s4 is None:
+            p = torch.cat((x, seg, r_inter_tanh_s8, r_inter_tanh_s8), 1)
+
+            f, f_1, f_2 = self.feats(p)
+            p = self.psp(f)
+            inter_s8_2 = self.final_28(p)
+            r_inter_s8_2 = F.interpolate(
+                inter_s8_2, scale_factor=8, mode="bilinear", align_corners=False
+            )
+            r_inter_tanh_s8_2 = torch.tanh(r_inter_s8_2)
+
+            p = self.up_1(p, f_2)
+
+            inter_s4 = self.final_56(p)
+            r_inter_s4 = F.interpolate(
+                inter_s4, scale_factor=4, mode="bilinear", align_corners=False
+            )
+            r_inter_tanh_s4 = torch.tanh(r_inter_s4)
+
+            images["pred_28_2"] = torch.sigmoid(r_inter_s8_2)
+            images["out_28_2"] = r_inter_s8_2
+            images["pred_56"] = torch.sigmoid(r_inter_s4)
+            images["out_56"] = r_inter_s4
+        else:
+            r_inter_tanh_s8_2 = inter_s8
+            r_inter_tanh_s4 = inter_s4
+
+        """
+        Third iteration, s1 output
+        """
+        p = torch.cat((x, seg, r_inter_tanh_s8_2, r_inter_tanh_s4), 1)
+
+        f, f_1, f_2 = self.feats(p)
+        p = self.psp(f)
+        inter_s8_3 = self.final_28(p)
+        r_inter_s8_3 = F.interpolate(
+            inter_s8_3, scale_factor=8, mode="bilinear", align_corners=False
+        )
+
+        p = self.up_1(p, f_2)
+        inter_s4_2 = self.final_56(p)
+        r_inter_s4_2 = F.interpolate(
+            inter_s4_2, scale_factor=4, mode="bilinear", align_corners=False
+        )
+        p = self.up_2(p, f_1)
+        p = self.up_3(p, x)
+
+        """
+        Final output
+        """
+        p = F.relu(self.final_11(torch.cat([p, x], 1)), inplace=True)
+        p = self.final_21(p)
+
+        pred_224 = torch.sigmoid(p)
+
+        images["pred_224"] = pred_224
+        images["out_224"] = p
+        images["pred_28_3"] = torch.sigmoid(r_inter_s8_3)
+        images["pred_56_2"] = torch.sigmoid(r_inter_s4_2)
+        images["out_28_3"] = r_inter_s8_3
+        images["out_56_2"] = r_inter_s4_2
+
+        return images
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class PSPModule +(features, out_features=1024, sizes=(1, 2, 3, 6)) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call Module.to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class PSPModule(nn.Module):
+    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
+        super().__init__()
+        self.stages = []
+        self.stages = nn.ModuleList(
+            [self._make_stage(features, size) for size in sizes]
+        )
+        self.bottleneck = nn.Conv2d(
+            features * (len(sizes) + 1), out_features, kernel_size=1
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+    def _make_stage(self, features, size):
+        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
+        return nn.Sequential(prior, conv)
+
+    def forward(self, feats):
+        h, w = feats.size(2), feats.size(3)
+        set_priors = [
+            F.interpolate(
+                input=stage(feats), size=(h, w), mode="bilinear", align_corners=False
+            )
+            for stage in self.stages
+        ]
+        priors = set_priors + [feats]
+        bottle = self.bottleneck(torch.cat(priors, 1))
+        return self.relu(bottle)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, feats) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance itself afterwards +instead of calling forward() directly, since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, feats):
+    h, w = feats.size(2), feats.size(3)
+    set_priors = [
+        F.interpolate(
+            input=stage(feats), size=(h, w), mode="bilinear", align_corners=False
+        )
+        for stage in self.stages
+    ]
+    priors = set_priors + [feats]
+    bottle = self.bottleneck(torch.cat(priors, 1))
+    return self.relu(bottle)
+
+
+
+
+
+class PSPUpsample +(x_channels, in_channels, out_channels) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class PSPUpsample(nn.Module):
+    def __init__(self, x_channels, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.BatchNorm2d(in_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+        )
+
+        self.conv2 = nn.Sequential(
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+        )
+
+        self.shortcut = nn.Conv2d(x_channels, out_channels, kernel_size=1)
+
+    def forward(self, x, up):
+        x = F.interpolate(input=x, scale_factor=2, mode="bilinear", align_corners=False)
+
+        p = self.conv(torch.cat([x, up], 1).type(x.type()))
+        sc = self.shortcut(x)
+
+        p = p + sc
+
+        p2 = self.conv2(p)
+
+        return p + p2
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x, up) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of calling forward() directly, since calling the instance runs the +registered hooks while a direct forward() call silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, up):
+    x = F.interpolate(input=x, scale_factor=2, mode="bilinear", align_corners=False)
+
+    p = self.conv(torch.cat([x, up], 1).type(x.type()))
+    sc = self.shortcut(x)
+
+    p = p + sc
+
+    p2 = self.conv2(p)
+
+    return p + p2
+
+
+
+
+
+class RefinementModule +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class RefinementModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.feats = resnet50()
+        self.psp = PSPModule(2048, 1024, (1, 2, 3, 6))
+
+        self.up_1 = PSPUpsample(1024, 1024 + 256, 512)
+        self.up_2 = PSPUpsample(512, 512 + 64, 256)
+        self.up_3 = PSPUpsample(256, 256 + 3, 32)
+
+        self.final_28 = nn.Sequential(
+            nn.Conv2d(1024, 32, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 1, kernel_size=1),
+        )
+
+        self.final_56 = nn.Sequential(
+            nn.Conv2d(512, 32, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 1, kernel_size=1),
+        )
+
+        self.final_11 = nn.Conv2d(32 + 3, 32, kernel_size=1)
+        self.final_21 = nn.Conv2d(32, 1, kernel_size=1)
+
+    def forward(self, x, seg, inter_s8=None, inter_s4=None):
+
+        images = {}
+
+        """
+        First iteration, s8 output
+        """
+        if inter_s8 is None:
+            p = torch.cat((x, seg, seg, seg), 1)
+
+            f, f_1, f_2 = self.feats(p)
+            p = self.psp(f)
+
+            inter_s8 = self.final_28(p)
+            r_inter_s8 = F.interpolate(
+                inter_s8, scale_factor=8, mode="bilinear", align_corners=False
+            )
+            r_inter_tanh_s8 = torch.tanh(r_inter_s8)
+
+            images["pred_28"] = torch.sigmoid(r_inter_s8)
+            images["out_28"] = r_inter_s8
+        else:
+            r_inter_tanh_s8 = inter_s8
+
+        """
+        Second iteration, s8 output
+        """
+        if inter_s4 is None:
+            p = torch.cat((x, seg, r_inter_tanh_s8, r_inter_tanh_s8), 1)
+
+            f, f_1, f_2 = self.feats(p)
+            p = self.psp(f)
+            inter_s8_2 = self.final_28(p)
+            r_inter_s8_2 = F.interpolate(
+                inter_s8_2, scale_factor=8, mode="bilinear", align_corners=False
+            )
+            r_inter_tanh_s8_2 = torch.tanh(r_inter_s8_2)
+
+            p = self.up_1(p, f_2)
+
+            inter_s4 = self.final_56(p)
+            r_inter_s4 = F.interpolate(
+                inter_s4, scale_factor=4, mode="bilinear", align_corners=False
+            )
+            r_inter_tanh_s4 = torch.tanh(r_inter_s4)
+
+            images["pred_28_2"] = torch.sigmoid(r_inter_s8_2)
+            images["out_28_2"] = r_inter_s8_2
+            images["pred_56"] = torch.sigmoid(r_inter_s4)
+            images["out_56"] = r_inter_s4
+        else:
+            r_inter_tanh_s8_2 = inter_s8
+            r_inter_tanh_s4 = inter_s4
+
+        """
+        Third iteration, s1 output
+        """
+        p = torch.cat((x, seg, r_inter_tanh_s8_2, r_inter_tanh_s4), 1)
+
+        f, f_1, f_2 = self.feats(p)
+        p = self.psp(f)
+        inter_s8_3 = self.final_28(p)
+        r_inter_s8_3 = F.interpolate(
+            inter_s8_3, scale_factor=8, mode="bilinear", align_corners=False
+        )
+
+        p = self.up_1(p, f_2)
+        inter_s4_2 = self.final_56(p)
+        r_inter_s4_2 = F.interpolate(
+            inter_s4_2, scale_factor=4, mode="bilinear", align_corners=False
+        )
+        p = self.up_2(p, f_1)
+        p = self.up_3(p, x)
+
+        """
+        Final output
+        """
+        p = F.relu(self.final_11(torch.cat([p, x], 1)), inplace=True)
+        p = self.final_21(p)
+
+        pred_224 = torch.sigmoid(p)
+
+        images["pred_224"] = pred_224
+        images["out_224"] = p
+        images["pred_28_3"] = torch.sigmoid(r_inter_s8_3)
+        images["pred_56_2"] = torch.sigmoid(r_inter_s4_2)
+        images["out_28_3"] = r_inter_s8_3
+        images["out_56_2"] = r_inter_s4_2
+
+        return images
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def forward(self, x, seg, inter_s8=None, inter_s4=None) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for the forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of calling forward() directly, since calling the instance runs the +registered hooks while a direct forward() call silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, seg, inter_s8=None, inter_s4=None):
+
+    images = {}
+
+    """
+    First iteration, s8 output
+    """
+    if inter_s8 is None:
+        p = torch.cat((x, seg, seg, seg), 1)
+
+        f, f_1, f_2 = self.feats(p)
+        p = self.psp(f)
+
+        inter_s8 = self.final_28(p)
+        r_inter_s8 = F.interpolate(
+            inter_s8, scale_factor=8, mode="bilinear", align_corners=False
+        )
+        r_inter_tanh_s8 = torch.tanh(r_inter_s8)
+
+        images["pred_28"] = torch.sigmoid(r_inter_s8)
+        images["out_28"] = r_inter_s8
+    else:
+        r_inter_tanh_s8 = inter_s8
+
+    """
+    Second iteration, s8 output
+    """
+    if inter_s4 is None:
+        p = torch.cat((x, seg, r_inter_tanh_s8, r_inter_tanh_s8), 1)
+
+        f, f_1, f_2 = self.feats(p)
+        p = self.psp(f)
+        inter_s8_2 = self.final_28(p)
+        r_inter_s8_2 = F.interpolate(
+            inter_s8_2, scale_factor=8, mode="bilinear", align_corners=False
+        )
+        r_inter_tanh_s8_2 = torch.tanh(r_inter_s8_2)
+
+        p = self.up_1(p, f_2)
+
+        inter_s4 = self.final_56(p)
+        r_inter_s4 = F.interpolate(
+            inter_s4, scale_factor=4, mode="bilinear", align_corners=False
+        )
+        r_inter_tanh_s4 = torch.tanh(r_inter_s4)
+
+        images["pred_28_2"] = torch.sigmoid(r_inter_s8_2)
+        images["out_28_2"] = r_inter_s8_2
+        images["pred_56"] = torch.sigmoid(r_inter_s4)
+        images["out_56"] = r_inter_s4
+    else:
+        r_inter_tanh_s8_2 = inter_s8
+        r_inter_tanh_s4 = inter_s4
+
+    """
+    Third iteration, s1 output
+    """
+    p = torch.cat((x, seg, r_inter_tanh_s8_2, r_inter_tanh_s4), 1)
+
+    f, f_1, f_2 = self.feats(p)
+    p = self.psp(f)
+    inter_s8_3 = self.final_28(p)
+    r_inter_s8_3 = F.interpolate(
+        inter_s8_3, scale_factor=8, mode="bilinear", align_corners=False
+    )
+
+    p = self.up_1(p, f_2)
+    inter_s4_2 = self.final_56(p)
+    r_inter_s4_2 = F.interpolate(
+        inter_s4_2, scale_factor=4, mode="bilinear", align_corners=False
+    )
+    p = self.up_2(p, f_1)
+    p = self.up_3(p, x)
+
+    """
+    Final output
+    """
+    p = F.relu(self.final_11(torch.cat([p, x], 1)), inplace=True)
+    p = self.final_21(p)
+
+    pred_224 = torch.sigmoid(p)
+
+    images["pred_224"] = pred_224
+    images["out_224"] = p
+    images["pred_28_3"] = torch.sigmoid(r_inter_s8_3)
+    images["pred_56_2"] = torch.sigmoid(r_inter_s4_2)
+    images["out_28_3"] = r_inter_s8_3
+    images["out_56_2"] = r_inter_s4_2
+
+    return images
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/cascadepsp/utils.html b/docs/api/carvekit/ml/arch/cascadepsp/utils.html new file mode 100644 index 0000000..a59b80d --- /dev/null +++ b/docs/api/carvekit/ml/arch/cascadepsp/utils.html @@ -0,0 +1,425 @@ + + + + + + +carvekit.ml.arch.cascadepsp.utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.cascadepsp.utils

+
+
+
+ +Expand source code + +
import torch
+import torch.nn.functional as F
+
+
+def resize_max_side(im, size, method):
+    h, w = im.shape[-2:]
+    max_side = max(h, w)
+    ratio = size / max_side
+    if method in ["bilinear", "bicubic"]:
+        return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False)
+    else:
+        return F.interpolate(im, scale_factor=ratio, mode=method)
+
+
+def process_high_res_im(model, im, seg, L=900):
+    stride = L // 2
+
+    _, _, h, w = seg.shape
+    if max(h, w) > L:
+        im_small = resize_max_side(im, L, "area")
+        seg_small = resize_max_side(seg, L, "area")
+    elif max(h, w) < L:
+        im_small = resize_max_side(im, L, "bicubic")
+        seg_small = resize_max_side(seg, L, "bilinear")
+    else:
+        im_small = im
+        seg_small = seg
+
+    images = model.safe_forward(im_small, seg_small)
+
+    pred_224 = images["pred_224"]
+    pred_56 = images["pred_56_2"]
+
+    for new_size in [max(h, w)]:
+        im_small = resize_max_side(im, new_size, "area")
+        seg_small = resize_max_side(seg, new_size, "area")
+        _, _, h, w = seg_small.shape
+
+        combined_224 = torch.zeros_like(seg_small)
+        combined_weight = torch.zeros_like(seg_small)
+
+        r_pred_224 = (
+            F.interpolate(pred_224, size=(h, w), mode="bilinear", align_corners=False)
+            > 0.5
+        ).float() * 2 - 1
+        r_pred_56 = (
+            F.interpolate(pred_56, size=(h, w), mode="bilinear", align_corners=False)
+            * 2
+            - 1
+        )
+
+        padding = 16
+        step_size = stride - padding * 2
+        step_len = L
+
+        used_start_idx = {}
+        for x_idx in range((w) // step_size + 1):
+            for y_idx in range((h) // step_size + 1):
+
+                start_x = x_idx * step_size
+                start_y = y_idx * step_size
+                end_x = start_x + step_len
+                end_y = start_y + step_len
+
+                # Shift when required
+                if end_y > h:
+                    end_y = h
+                    start_y = h - step_len
+                if end_x > w:
+                    end_x = w
+                    start_x = w - step_len
+
+                # Bound x/y range
+                start_x = max(0, start_x)
+                start_y = max(0, start_y)
+                end_x = min(w, end_x)
+                end_y = min(h, end_y)
+
+                # The same crop might appear twice due to bounding/shifting
+                start_idx = start_y * w + start_x
+                if start_idx in used_start_idx:
+                    continue
+                else:
+                    used_start_idx[start_idx] = True
+
+                # Take crop
+                im_part = im_small[:, :, start_y:end_y, start_x:end_x]
+                seg_224_part = r_pred_224[:, :, start_y:end_y, start_x:end_x]
+                seg_56_part = r_pred_56[:, :, start_y:end_y, start_x:end_x]
+
+                # Skip when it is not an interesting crop anyway
+                seg_part_norm = (seg_224_part > 0).float()
+                high_thres = 0.9
+                low_thres = 0.1
+                if (seg_part_norm.mean() > high_thres) or (
+                    seg_part_norm.mean() < low_thres
+                ):
+                    continue
+                grid_images = model.safe_forward(im_part, seg_224_part, seg_56_part)
+                grid_pred_224 = grid_images["pred_224"]
+
+                # Padding
+                pred_sx = pred_sy = 0
+                pred_ex = step_len
+                pred_ey = step_len
+
+                if start_x != 0:
+                    start_x += padding
+                    pred_sx += padding
+                if start_y != 0:
+                    start_y += padding
+                    pred_sy += padding
+                if end_x != w:
+                    end_x -= padding
+                    pred_ex -= padding
+                if end_y != h:
+                    end_y -= padding
+                    pred_ey -= padding
+
+                combined_224[:, :, start_y:end_y, start_x:end_x] += grid_pred_224[
+                    :, :, pred_sy:pred_ey, pred_sx:pred_ex
+                ]
+
+                del grid_pred_224
+
+                # Used for averaging
+                combined_weight[:, :, start_y:end_y, start_x:end_x] += 1
+
+        # Final full resolution output
+        seg_norm = r_pred_224 / 2 + 0.5
+        pred_224 = combined_224 / combined_weight
+        pred_224 = torch.where(combined_weight == 0, seg_norm, pred_224)
+
+    _, _, h, w = seg.shape
+    images = {}
+    images["pred_224"] = F.interpolate(
+        pred_224, size=(h, w), mode="bilinear", align_corners=True
+    )
+
+    return images["pred_224"]
+
+
+def process_im_single_pass(model, im, seg, L=900):
+    """
+    A single pass version, aka global step only.
+    """
+
+    _, _, h, w = im.shape
+    if max(h, w) < L:
+        im = resize_max_side(im, L, "bicubic")
+        seg = resize_max_side(seg, L, "bilinear")
+
+    if max(h, w) > L:
+        im = resize_max_side(im, L, "area")
+        seg = resize_max_side(seg, L, "area")
+
+    images = model.safe_forward(im, seg)
+
+    if max(h, w) < L:
+        images["pred_224"] = F.interpolate(images["pred_224"], size=(h, w), mode="area")
+    elif max(h, w) > L:
+        images["pred_224"] = F.interpolate(
+            images["pred_224"], size=(h, w), mode="bilinear", align_corners=True
+        )
+
+    return images["pred_224"]
+
+
+
+
+
+
+
+

Functions

+
+
+def process_high_res_im(model, im, seg, L=900) +
+
+
+
+ +Expand source code + +
def process_high_res_im(model, im, seg, L=900):
+    stride = L // 2
+
+    _, _, h, w = seg.shape
+    if max(h, w) > L:
+        im_small = resize_max_side(im, L, "area")
+        seg_small = resize_max_side(seg, L, "area")
+    elif max(h, w) < L:
+        im_small = resize_max_side(im, L, "bicubic")
+        seg_small = resize_max_side(seg, L, "bilinear")
+    else:
+        im_small = im
+        seg_small = seg
+
+    images = model.safe_forward(im_small, seg_small)
+
+    pred_224 = images["pred_224"]
+    pred_56 = images["pred_56_2"]
+
+    for new_size in [max(h, w)]:
+        im_small = resize_max_side(im, new_size, "area")
+        seg_small = resize_max_side(seg, new_size, "area")
+        _, _, h, w = seg_small.shape
+
+        combined_224 = torch.zeros_like(seg_small)
+        combined_weight = torch.zeros_like(seg_small)
+
+        r_pred_224 = (
+            F.interpolate(pred_224, size=(h, w), mode="bilinear", align_corners=False)
+            > 0.5
+        ).float() * 2 - 1
+        r_pred_56 = (
+            F.interpolate(pred_56, size=(h, w), mode="bilinear", align_corners=False)
+            * 2
+            - 1
+        )
+
+        padding = 16
+        step_size = stride - padding * 2
+        step_len = L
+
+        used_start_idx = {}
+        for x_idx in range((w) // step_size + 1):
+            for y_idx in range((h) // step_size + 1):
+
+                start_x = x_idx * step_size
+                start_y = y_idx * step_size
+                end_x = start_x + step_len
+                end_y = start_y + step_len
+
+                # Shift when required
+                if end_y > h:
+                    end_y = h
+                    start_y = h - step_len
+                if end_x > w:
+                    end_x = w
+                    start_x = w - step_len
+
+                # Bound x/y range
+                start_x = max(0, start_x)
+                start_y = max(0, start_y)
+                end_x = min(w, end_x)
+                end_y = min(h, end_y)
+
+                # The same crop might appear twice due to bounding/shifting
+                start_idx = start_y * w + start_x
+                if start_idx in used_start_idx:
+                    continue
+                else:
+                    used_start_idx[start_idx] = True
+
+                # Take crop
+                im_part = im_small[:, :, start_y:end_y, start_x:end_x]
+                seg_224_part = r_pred_224[:, :, start_y:end_y, start_x:end_x]
+                seg_56_part = r_pred_56[:, :, start_y:end_y, start_x:end_x]
+
+                # Skip when it is not an interesting crop anyway
+                seg_part_norm = (seg_224_part > 0).float()
+                high_thres = 0.9
+                low_thres = 0.1
+                if (seg_part_norm.mean() > high_thres) or (
+                    seg_part_norm.mean() < low_thres
+                ):
+                    continue
+                grid_images = model.safe_forward(im_part, seg_224_part, seg_56_part)
+                grid_pred_224 = grid_images["pred_224"]
+
+                # Padding
+                pred_sx = pred_sy = 0
+                pred_ex = step_len
+                pred_ey = step_len
+
+                if start_x != 0:
+                    start_x += padding
+                    pred_sx += padding
+                if start_y != 0:
+                    start_y += padding
+                    pred_sy += padding
+                if end_x != w:
+                    end_x -= padding
+                    pred_ex -= padding
+                if end_y != h:
+                    end_y -= padding
+                    pred_ey -= padding
+
+                combined_224[:, :, start_y:end_y, start_x:end_x] += grid_pred_224[
+                    :, :, pred_sy:pred_ey, pred_sx:pred_ex
+                ]
+
+                del grid_pred_224
+
+                # Used for averaging
+                combined_weight[:, :, start_y:end_y, start_x:end_x] += 1
+
+        # Final full resolution output
+        seg_norm = r_pred_224 / 2 + 0.5
+        pred_224 = combined_224 / combined_weight
+        pred_224 = torch.where(combined_weight == 0, seg_norm, pred_224)
+
+    _, _, h, w = seg.shape
+    images = {}
+    images["pred_224"] = F.interpolate(
+        pred_224, size=(h, w), mode="bilinear", align_corners=True
+    )
+
+    return images["pred_224"]
+
+
+
+def process_im_single_pass(model, im, seg, L=900) +
+
+

A single pass version, aka global step only.

+
+ +Expand source code + +
def process_im_single_pass(model, im, seg, L=900):
+    """
+    A single pass version, aka global step only.
+    """
+
+    _, _, h, w = im.shape
+    if max(h, w) < L:
+        im = resize_max_side(im, L, "bicubic")
+        seg = resize_max_side(seg, L, "bilinear")
+
+    if max(h, w) > L:
+        im = resize_max_side(im, L, "area")
+        seg = resize_max_side(seg, L, "area")
+
+    images = model.safe_forward(im, seg)
+
+    if max(h, w) < L:
+        images["pred_224"] = F.interpolate(images["pred_224"], size=(h, w), mode="area")
+    elif max(h, w) > L:
+        images["pred_224"] = F.interpolate(
+            images["pred_224"], size=(h, w), mode="bilinear", align_corners=True
+        )
+
+    return images["pred_224"]
+
+
+
+def resize_max_side(im, size, method) +
+
+
+
+ +Expand source code + +
def resize_max_side(im, size, method):
+    h, w = im.shape[-2:]
+    max_side = max(h, w)
+    ratio = size / max_side
+    if method in ["bilinear", "bicubic"]:
+        return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False)
+    else:
+        return F.interpolate(im, scale_factor=ratio, mode=method)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/fba_matting/index.html b/docs/api/carvekit/ml/arch/fba_matting/index.html new file mode 100644 index 0000000..5a059ae --- /dev/null +++ b/docs/api/carvekit/ml/arch/fba_matting/index.html @@ -0,0 +1,95 @@ + + + + + + +carvekit.ml.arch.fba_matting API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.fba_matting

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.fba_matting.layers_WS
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+
carvekit.ml.arch.fba_matting.models
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+
carvekit.ml.arch.fba_matting.resnet_GN_WS
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+
carvekit.ml.arch.fba_matting.resnet_bn
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+
carvekit.ml.arch.fba_matting.transforms
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/fba_matting/layers_WS.html b/docs/api/carvekit/ml/arch/fba_matting/layers_WS.html new file mode 100644 index 0000000..1d0f5f9 --- /dev/null +++ b/docs/api/carvekit/ml/arch/fba_matting/layers_WS.html @@ -0,0 +1,383 @@ + + + + + + +carvekit.ml.arch.fba_matting.layers_WS API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.fba_matting.layers_WS

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class Conv2d(nn.Conv2d):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super(Conv2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def forward(self, x):
+        # return super(Conv2d, self).forward(x)
+        weight = self.weight
+        weight_mean = (
+            weight.mean(dim=1, keepdim=True)
+            .mean(dim=2, keepdim=True)
+            .mean(dim=3, keepdim=True)
+        )
+        weight = weight - weight_mean
+        # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        std = (
+            torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(
+                -1, 1, 1, 1
+            )
+            + 1e-5
+        )
+        weight = weight / std.expand_as(weight)
+        return F.conv2d(
+            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+        )
+
+
+def BatchNorm2d(num_features):
+    return nn.GroupNorm(num_channels=num_features, num_groups=32)
+
+
+
+
+
+
+
+

Functions

+
+
+def BatchNorm2d(num_features) +
+
+
+
+ +Expand source code + +
def BatchNorm2d(num_features):
+    return nn.GroupNorm(num_channels=num_features, num_groups=32)
+
+
+
+
+
+

Classes

+
+
+class Conv2d +(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) +
+
+

Applies a 2D convolution over an input signal composed of several input +planes.

+

In the simplest case, the output value of the layer with input size +:math:(N, C_{\text{in}}, H, W) and output :math:(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}}) +can be precisely described as:

+

\[ \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + +\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) \] +where :math:\star is the valid 2D cross-correlation operator, +:math:N is a batch size, :math:C denotes a number of channels, +:math:H is a height of input planes in pixels, and :math:W is +width in pixels.

+

This module supports :ref:TensorFloat32<tf32_on_ampere>.

+
    +
  • +

    :attr:stride controls the stride for the cross-correlation, a single +number or a tuple.

    +
  • +
  • +

    :attr:padding controls the amount of padding applied to the input. It +can be either a string {'valid', 'same'} or a tuple of ints giving the +amount of implicit padding applied on both sides.

    +
  • +
  • +

    :attr:dilation controls the spacing between the kernel points; also +known as the à trous algorithm. It is harder to describe, but this link_ +has a nice visualization of what :attr:dilation does.

    +
  • +
  • +

    :attr:groups controls the connections between inputs and outputs. +:attr:in_channels and :attr:out_channels must both be divisible by +:attr:groups. For example,

    +
      +
    • At groups=1, all inputs are convolved to all outputs.
    • +
    • At groups=2, the operation becomes equivalent to having two conv +layers side by side, each seeing half the input channels +and producing half the output channels, and both subsequently +concatenated.
    • +
    • At groups= :attr:in_channels, each input channel is convolved with +its own set of filters (of size +:math:\frac{\text{out\_channels}}{\text{in\_channels}}).
    • +
    +
  • +
+

The parameters :attr:kernel_size, :attr:stride, :attr:padding, :attr:dilation can either be:

+
- a single <code>int</code> -- in which case the same value is used for the height and width dimension
+- a <code>tuple</code> of two ints -- in which case, the first <code>int</code> is used for the height dimension,
+  and the second <code>int</code> for the width dimension
+
+

Note

+

When groups == in_channels and out_channels == K * in_channels, +where K is a positive integer, this operation is also known as a "depthwise convolution".

+

In other words, for an input of size :math:(N, C_{in}, L_{in}), +a depthwise convolution with a depthwise multiplier K can be performed with the arguments +:math:(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in}).

+

Note

+

In some circumstances when given tensors on a CUDA device and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See :doc:/notes/randomness for more information.

+

Note

+

padding='valid' is the same as no padding. padding='same' pads +the input so the output has the shape as the input. However, this mode +doesn't support any stride values other than 1.

+

Args

+
+
in_channels : int
+
Number of channels in the input image
+
out_channels : int
+
Number of channels produced by the convolution
+
kernel_size : int or tuple
+
Size of the convolving kernel
+
stride : int or tuple, optional
+
Stride of the convolution. Default: 1
+
padding : int, tuple or str, optional
+
Padding added to all four sides of +the input. Default: 0
+
padding_mode : string, optional
+
'zeros', 'reflect', +'replicate' or 'circular'. Default: 'zeros'
+
dilation : int or tuple, optional
+
Spacing between kernel elements. Default: 1
+
groups : int, optional
+
Number of blocked connections from input +channels to output channels. Default: 1
+
bias : bool, optional
+
If True, adds a learnable bias to the +output. Default: True
+
+

Shape

+
    +
  • Input: :math:(N, C_{in}, H_{in}, W_{in}) or :math:(C_{in}, H_{in}, W_{in})
  • +
  • Output: :math:(N, C_{out}, H_{out}, W_{out}) or :math:(C_{out}, H_{out}, W_{out}), where
  • +
+

\[ H_{out} = \left\lfloor\frac{H_{in} ++ 2 \times \text{padding}[0] - \text{dilation}[0] +\times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor \] +\[ W_{out} = \left\lfloor\frac{W_{in} ++ 2 \times \text{padding}[1] - \text{dilation}[1] +\times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor \]

+

Attributes

+
+
weight : Tensor
+
the learnable weights of the module of shape +:math:(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}}, +:math:\text{kernel\_size[0]}, \text{kernel\_size[1]}). +The values of these weights are sampled from +:math:\mathcal{U}(-\sqrt{k}, \sqrt{k}) where +:math:k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}
+
bias : Tensor
+
the learnable bias of the module of shape +(out_channels). If :attr:bias is True, +then the values of these weights are +sampled from :math:\mathcal{U}(-\sqrt{k}, \sqrt{k}) where +:math:k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}
+
+

Examples

+
>>> # With square kernels and equal stride
+>>> m = nn.Conv2d(16, 33, 3, stride=2)
+>>> # non-square kernels and unequal stride and with padding
+>>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+>>> # non-square kernels and unequal stride and with padding and dilation
+>>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
+>>> input = torch.randn(20, 16, 50, 100)
+>>> output = m(input)
+
+

.. _cross-correlation: +https://en.wikipedia.org/wiki/Cross-correlation

+

.. _link: +https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Conv2d(nn.Conv2d):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super(Conv2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def forward(self, x):
+        # return super(Conv2d, self).forward(x)
+        weight = self.weight
+        weight_mean = (
+            weight.mean(dim=1, keepdim=True)
+            .mean(dim=2, keepdim=True)
+            .mean(dim=3, keepdim=True)
+        )
+        weight = weight - weight_mean
+        # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        std = (
+            torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(
+                -1, 1, 1, 1
+            )
+            + 1e-5
+        )
+        weight = weight / std.expand_as(weight)
+        return F.conv2d(
+            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+        )
+
+

Ancestors

+
    +
  • torch.nn.modules.conv.Conv2d
  • +
  • torch.nn.modules.conv._ConvNd
  • +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    # return super(Conv2d, self).forward(x)
+    weight = self.weight
+    weight_mean = (
+        weight.mean(dim=1, keepdim=True)
+        .mean(dim=2, keepdim=True)
+        .mean(dim=3, keepdim=True)
+    )
+    weight = weight - weight_mean
+    # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+    std = (
+        torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(
+            -1, 1, 1, 1
+        )
+        + 1e-5
+    )
+    weight = weight / std.expand_as(weight)
+    return F.conv2d(
+        x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+    )
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/fba_matting/models.html b/docs/api/carvekit/ml/arch/fba_matting/models.html new file mode 100644 index 0000000..c473494 --- /dev/null +++ b/docs/api/carvekit/ml/arch/fba_matting/models.html @@ -0,0 +1,1236 @@ + + + + + + +carvekit.ml.arch.fba_matting.models API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.fba_matting.models

+
+
+

Modified by Nikita Selin ([OPHoperHPO](https://github.com/OPHoperHPO)). Source url: https://github.com/MarcoForte/FBA_Matting License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch
+import torch.nn as nn
+import carvekit.ml.arch.fba_matting.resnet_GN_WS as resnet_GN_WS
+import carvekit.ml.arch.fba_matting.layers_WS as L
+import carvekit.ml.arch.fba_matting.resnet_bn as resnet_bn
+from functools import partial
+
+
+class FBA(nn.Module):
+    def __init__(self, encoder: str):
+        super(FBA, self).__init__()
+        self.encoder = build_encoder(arch=encoder)
+        self.decoder = fba_decoder(batch_norm=True if "BN" in encoder else False)
+
+    def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
+        resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
+        conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
+        return self.decoder(conv_out, image, indices, two_chan_trimap)
+
+
+class ResnetDilatedBN(nn.Module):
+    def __init__(self, orig_resnet, dilate_scale=8):
+        super(ResnetDilatedBN, self).__init__()
+
+        if dilate_scale == 8:
+            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+        elif dilate_scale == 16:
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu1 = orig_resnet.relu1
+        self.conv2 = orig_resnet.conv2
+        self.bn2 = orig_resnet.bn2
+        self.relu2 = orig_resnet.relu2
+        self.conv3 = orig_resnet.conv3
+        self.bn3 = orig_resnet.bn3
+        self.relu3 = orig_resnet.relu3
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            # the convolution with stride
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convoluions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = [x]
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out, indices
+        return [x]
+
+
+class Resnet(nn.Module):
+    def __init__(self, orig_resnet):
+        super(Resnet, self).__init__()
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu1 = orig_resnet.relu1
+        self.conv2 = orig_resnet.conv2
+        self.bn2 = orig_resnet.bn2
+        self.relu2 = orig_resnet.relu2
+        self.conv3 = orig_resnet.conv3
+        self.bn3 = orig_resnet.bn3
+        self.relu3 = orig_resnet.relu3
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = []
+
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out
+        return [x]
+
+
+class ResnetDilated(nn.Module):
+    def __init__(self, orig_resnet, dilate_scale=8):
+        super(ResnetDilated, self).__init__()
+
+        if dilate_scale == 8:
+            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+        elif dilate_scale == 16:
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu = orig_resnet.relu
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            # the convolution with stride
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convoluions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = [x]
+        x = self.relu(self.bn1(self.conv1(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out, indices
+        return [x]
+
+
+def norm(dim, bn=False):
+    if bn is False:
+        return nn.GroupNorm(32, dim)
+    else:
+        return nn.BatchNorm2d(dim)
+
+
+def fba_fusion(alpha, img, F, B):
+    F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B
+    B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F
+
+    F = torch.clamp(F, 0, 1)
+    B = torch.clamp(B, 0, 1)
+    la = 0.1
+    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
+        torch.sum((F - B) * (F - B), 1, keepdim=True) + la
+    )
+    alpha = torch.clamp(alpha, 0, 1)
+    return alpha, F, B
+
+
+class fba_decoder(nn.Module):
+    def __init__(self, batch_norm=False):
+        super(fba_decoder, self).__init__()
+        pool_scales = (1, 2, 3, 6)
+        self.batch_norm = batch_norm
+
+        self.ppm = []
+
+        for scale in pool_scales:
+            self.ppm.append(
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(scale),
+                    L.Conv2d(2048, 256, kernel_size=1, bias=True),
+                    norm(256, self.batch_norm),
+                    nn.LeakyReLU(),
+                )
+            )
+        self.ppm = nn.ModuleList(self.ppm)
+
+        self.conv_up1 = nn.Sequential(
+            L.Conv2d(
+                2048 + len(pool_scales) * 256, 256, kernel_size=3, padding=1, bias=True
+            ),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+            L.Conv2d(256, 256, kernel_size=3, padding=1),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+
+        self.conv_up2 = nn.Sequential(
+            L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+        if self.batch_norm:
+            d_up3 = 128
+        else:
+            d_up3 = 64
+        self.conv_up3 = nn.Sequential(
+            L.Conv2d(256 + d_up3, 64, kernel_size=3, padding=1, bias=True),
+            norm(64, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+
+        self.unpool = nn.MaxUnpool2d(2, stride=2)
+
+        self.conv_up4 = nn.Sequential(
+            nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(),
+            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(),
+            nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True),
+        )
+
+    def forward(self, conv_out, img, indices, two_chan_trimap):
+        conv5 = conv_out[-1]
+
+        input_size = conv5.size()
+        ppm_out = [conv5]
+        for pool_scale in self.ppm:
+            ppm_out.append(
+                nn.functional.interpolate(
+                    pool_scale(conv5),
+                    (input_size[2], input_size[3]),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+            )
+        ppm_out = torch.cat(ppm_out, 1)
+        x = self.conv_up1(ppm_out)
+
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        x = torch.cat((x, conv_out[-4]), 1)
+
+        x = self.conv_up2(x)
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        x = torch.cat((x, conv_out[-5]), 1)
+        x = self.conv_up3(x)
+
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
+
+        output = self.conv_up4(x)
+
+        alpha = torch.clamp(output[:, 0][:, None], 0, 1)
+        F = torch.sigmoid(output[:, 1:4])
+        B = torch.sigmoid(output[:, 4:7])
+
+        # FBA Fusion
+        alpha, F, B = fba_fusion(alpha, img, F, B)
+
+        output = torch.cat((alpha, F, B), 1)
+
+        return output
+
+
+def build_encoder(arch="resnet50_GN"):
+    if arch == "resnet50_GN_WS":
+        orig_resnet = resnet_GN_WS.__dict__["l_resnet50"]()
+        net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+    elif arch == "resnet50_BN":
+        orig_resnet = resnet_bn.__dict__["l_resnet50"]()
+        net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8)
+
+    else:
+        raise ValueError("Architecture undefined!")
+
+    num_channels = 3 + 6 + 2
+
+    if num_channels > 3:
+        net_encoder_sd = net_encoder.state_dict()
+        conv1_weights = net_encoder_sd["conv1.weight"]
+
+        c_out, c_in, h, w = conv1_weights.size()
+        conv1_mod = torch.zeros(c_out, num_channels, h, w)
+        conv1_mod[:, :3, :, :] = conv1_weights
+
+        conv1 = net_encoder.conv1
+        conv1.in_channels = num_channels
+        conv1.weight = torch.nn.Parameter(conv1_mod)
+
+        net_encoder.conv1 = conv1
+
+        net_encoder_sd["conv1.weight"] = conv1_mod
+
+        net_encoder.load_state_dict(net_encoder_sd)
+    return net_encoder
+
+
+
+
+
+
+
+

Functions

+
+
+def build_encoder(arch='resnet50_GN') +
+
+
+
+ +Expand source code + +
def build_encoder(arch="resnet50_GN"):
+    if arch == "resnet50_GN_WS":
+        orig_resnet = resnet_GN_WS.__dict__["l_resnet50"]()
+        net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+    elif arch == "resnet50_BN":
+        orig_resnet = resnet_bn.__dict__["l_resnet50"]()
+        net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8)
+
+    else:
+        raise ValueError("Architecture undefined!")
+
+    num_channels = 3 + 6 + 2
+
+    if num_channels > 3:
+        net_encoder_sd = net_encoder.state_dict()
+        conv1_weights = net_encoder_sd["conv1.weight"]
+
+        c_out, c_in, h, w = conv1_weights.size()
+        conv1_mod = torch.zeros(c_out, num_channels, h, w)
+        conv1_mod[:, :3, :, :] = conv1_weights
+
+        conv1 = net_encoder.conv1
+        conv1.in_channels = num_channels
+        conv1.weight = torch.nn.Parameter(conv1_mod)
+
+        net_encoder.conv1 = conv1
+
+        net_encoder_sd["conv1.weight"] = conv1_mod
+
+        net_encoder.load_state_dict(net_encoder_sd)
+    return net_encoder
+
+
+
+def fba_fusion(alpha, img, F, B) +
+
+
+
+ +Expand source code + +
def fba_fusion(alpha, img, F, B):
+    F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B
+    B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F
+
+    F = torch.clamp(F, 0, 1)
+    B = torch.clamp(B, 0, 1)
+    la = 0.1
+    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
+        torch.sum((F - B) * (F - B), 1, keepdim=True) + la
+    )
+    alpha = torch.clamp(alpha, 0, 1)
+    return alpha, F, B
+
+
+
+def norm(dim, bn=False) +
+
+
+
+ +Expand source code + +
def norm(dim, bn=False):
+    if bn is False:
+        return nn.GroupNorm(32, dim)
+    else:
+        return nn.BatchNorm2d(dim)
+
+
+
+
+
+

Classes

+
+
+class FBA +(encoder: str) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class FBA(nn.Module):
+    def __init__(self, encoder: str):
+        super(FBA, self).__init__()
+        self.encoder = build_encoder(arch=encoder)
+        self.decoder = fba_decoder(batch_norm=True if "BN" in encoder else False)
+
+    def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
+        resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
+        conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
+        return self.decoder(conv_out, image, indices, two_chan_trimap)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def forward(self, image, two_chan_trimap, image_n, trimap_transformed) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
+    resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
+    conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
+    return self.decoder(conv_out, image, indices, two_chan_trimap)
+
+
+
+
+
+class Resnet +(orig_resnet) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Resnet(nn.Module):
+    def __init__(self, orig_resnet):
+        super(Resnet, self).__init__()
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu1 = orig_resnet.relu1
+        self.conv2 = orig_resnet.conv2
+        self.bn2 = orig_resnet.bn2
+        self.relu2 = orig_resnet.relu2
+        self.conv3 = orig_resnet.conv3
+        self.bn3 = orig_resnet.bn3
+        self.relu3 = orig_resnet.relu3
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = []
+
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out
+        return [x]
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x, return_feature_maps=False) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, return_feature_maps=False):
+    conv_out = []
+
+    x = self.relu1(self.bn1(self.conv1(x)))
+    x = self.relu2(self.bn2(self.conv2(x)))
+    x = self.relu3(self.bn3(self.conv3(x)))
+    conv_out.append(x)
+    x, indices = self.maxpool(x)
+
+    x = self.layer1(x)
+    conv_out.append(x)
+    x = self.layer2(x)
+    conv_out.append(x)
+    x = self.layer3(x)
+    conv_out.append(x)
+    x = self.layer4(x)
+    conv_out.append(x)
+
+    if return_feature_maps:
+        return conv_out
+    return [x]
+
+
+
+
+
+class ResnetDilated +(orig_resnet, dilate_scale=8) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResnetDilated(nn.Module):
+    def __init__(self, orig_resnet, dilate_scale=8):
+        super(ResnetDilated, self).__init__()
+
+        if dilate_scale == 8:
+            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+        elif dilate_scale == 16:
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu = orig_resnet.relu
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            # the convolution with stride
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convoluions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = [x]
+        x = self.relu(self.bn1(self.conv1(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out, indices
+        return [x]
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x, return_feature_maps=False) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, return_feature_maps=False):
+    conv_out = [x]
+    x = self.relu(self.bn1(self.conv1(x)))
+    conv_out.append(x)
+    x, indices = self.maxpool(x)
+    x = self.layer1(x)
+    conv_out.append(x)
+    x = self.layer2(x)
+    conv_out.append(x)
+    x = self.layer3(x)
+    conv_out.append(x)
+    x = self.layer4(x)
+    conv_out.append(x)
+
+    if return_feature_maps:
+        return conv_out, indices
+    return [x]
+
+
+
+
+
+class ResnetDilatedBN +(orig_resnet, dilate_scale=8) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResnetDilatedBN(nn.Module):
+    def __init__(self, orig_resnet, dilate_scale=8):
+        super(ResnetDilatedBN, self).__init__()
+
+        if dilate_scale == 8:
+            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+        elif dilate_scale == 16:
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu1 = orig_resnet.relu1
+        self.conv2 = orig_resnet.conv2
+        self.bn2 = orig_resnet.bn2
+        self.relu2 = orig_resnet.relu2
+        self.conv3 = orig_resnet.conv3
+        self.bn3 = orig_resnet.bn3
+        self.relu3 = orig_resnet.relu3
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            # the convolution with stride
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convoluions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = [x]
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out, indices
+        return [x]
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x, return_feature_maps=False) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, return_feature_maps=False):
+    conv_out = [x]
+    x = self.relu1(self.bn1(self.conv1(x)))
+    x = self.relu2(self.bn2(self.conv2(x)))
+    x = self.relu3(self.bn3(self.conv3(x)))
+    conv_out.append(x)
+    x, indices = self.maxpool(x)
+    x = self.layer1(x)
+    conv_out.append(x)
+    x = self.layer2(x)
+    conv_out.append(x)
+    x = self.layer3(x)
+    conv_out.append(x)
+    x = self.layer4(x)
+    conv_out.append(x)
+
+    if return_feature_maps:
+        return conv_out, indices
+    return [x]
+
+
+
+
+
+class fba_decoder +(batch_norm=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class fba_decoder(nn.Module):
+    def __init__(self, batch_norm=False):
+        super(fba_decoder, self).__init__()
+        pool_scales = (1, 2, 3, 6)
+        self.batch_norm = batch_norm
+
+        self.ppm = []
+
+        for scale in pool_scales:
+            self.ppm.append(
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(scale),
+                    L.Conv2d(2048, 256, kernel_size=1, bias=True),
+                    norm(256, self.batch_norm),
+                    nn.LeakyReLU(),
+                )
+            )
+        self.ppm = nn.ModuleList(self.ppm)
+
+        self.conv_up1 = nn.Sequential(
+            L.Conv2d(
+                2048 + len(pool_scales) * 256, 256, kernel_size=3, padding=1, bias=True
+            ),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+            L.Conv2d(256, 256, kernel_size=3, padding=1),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+
+        self.conv_up2 = nn.Sequential(
+            L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+        if self.batch_norm:
+            d_up3 = 128
+        else:
+            d_up3 = 64
+        self.conv_up3 = nn.Sequential(
+            L.Conv2d(256 + d_up3, 64, kernel_size=3, padding=1, bias=True),
+            norm(64, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+
+        self.unpool = nn.MaxUnpool2d(2, stride=2)
+
+        self.conv_up4 = nn.Sequential(
+            nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(),
+            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(),
+            nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True),
+        )
+
+    def forward(self, conv_out, img, indices, two_chan_trimap):
+        conv5 = conv_out[-1]
+
+        input_size = conv5.size()
+        ppm_out = [conv5]
+        for pool_scale in self.ppm:
+            ppm_out.append(
+                nn.functional.interpolate(
+                    pool_scale(conv5),
+                    (input_size[2], input_size[3]),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+            )
+        ppm_out = torch.cat(ppm_out, 1)
+        x = self.conv_up1(ppm_out)
+
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        x = torch.cat((x, conv_out[-4]), 1)
+
+        x = self.conv_up2(x)
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        x = torch.cat((x, conv_out[-5]), 1)
+        x = self.conv_up3(x)
+
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
+
+        output = self.conv_up4(x)
+
+        alpha = torch.clamp(output[:, 0][:, None], 0, 1)
+        F = torch.sigmoid(output[:, 1:4])
+        B = torch.sigmoid(output[:, 4:7])
+
+        # FBA Fusion
+        alpha, F, B = fba_fusion(alpha, img, F, B)
+
+        output = torch.cat((alpha, F, B), 1)
+
+        return output
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, conv_out, img, indices, two_chan_trimap) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, conv_out, img, indices, two_chan_trimap):
+    conv5 = conv_out[-1]
+
+    input_size = conv5.size()
+    ppm_out = [conv5]
+    for pool_scale in self.ppm:
+        ppm_out.append(
+            nn.functional.interpolate(
+                pool_scale(conv5),
+                (input_size[2], input_size[3]),
+                mode="bilinear",
+                align_corners=False,
+            )
+        )
+    ppm_out = torch.cat(ppm_out, 1)
+    x = self.conv_up1(ppm_out)
+
+    x = torch.nn.functional.interpolate(
+        x, scale_factor=2, mode="bilinear", align_corners=False
+    )
+
+    x = torch.cat((x, conv_out[-4]), 1)
+
+    x = self.conv_up2(x)
+    x = torch.nn.functional.interpolate(
+        x, scale_factor=2, mode="bilinear", align_corners=False
+    )
+
+    x = torch.cat((x, conv_out[-5]), 1)
+    x = self.conv_up3(x)
+
+    x = torch.nn.functional.interpolate(
+        x, scale_factor=2, mode="bilinear", align_corners=False
+    )
+    x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
+
+    output = self.conv_up4(x)
+
+    alpha = torch.clamp(output[:, 0][:, None], 0, 1)
+    F = torch.sigmoid(output[:, 1:4])
+    B = torch.sigmoid(output[:, 4:7])
+
+    # FBA Fusion
+    alpha, F, B = fba_fusion(alpha, img, F, B)
+
+    output = torch.cat((alpha, F, B), 1)
+
+    return output
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/fba_matting/resnet_GN_WS.html b/docs/api/carvekit/ml/arch/fba_matting/resnet_GN_WS.html new file mode 100644 index 0000000..3c25b9b --- /dev/null +++ b/docs/api/carvekit/ml/arch/fba_matting/resnet_GN_WS.html @@ -0,0 +1,388 @@ + + + + + + +carvekit.ml.arch.fba_matting.resnet_GN_WS API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.fba_matting.resnet_GN_WS

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch.nn as nn
+import carvekit.ml.arch.fba_matting.layers_WS as L
+
+__all__ = ["ResNet", "l_resnet50"]
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return L.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = L.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = L.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = L.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = L.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = L.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1000):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = L.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, return_indices=True
+        )
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                L.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def l_resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+
+
+
+
+
+

Functions

+
+
+def l_resnet50(pretrained=False, **kwargs) +
+
+

Constructs a ResNet-50 model.

+

Args

+
+
pretrained : bool
+
If True, returns a model pre-trained on ImageNet
+
+
+ +Expand source code + +
def l_resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+
+
+
+

Classes

+
+
+class ResNet +(block, layers, num_classes=1000) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing them to be nested in +a tree structure. You can assign the submodules as regular attributes:

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1000):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = L.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, return_indices=True
+        )
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                L.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.conv1(x)
+    x = self.bn1(x)
+    x = self.relu(x)
+    x = self.maxpool(x)
+
+    x = self.layer1(x)
+    x = self.layer2(x)
+    x = self.layer3(x)
+    x = self.layer4(x)
+
+    x = self.avgpool(x)
+    x = x.view(x.size(0), -1)
+    x = self.fc(x)
+
+    return x
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/fba_matting/resnet_bn.html b/docs/api/carvekit/ml/arch/fba_matting/resnet_bn.html new file mode 100644 index 0000000..cb43b4c --- /dev/null +++ b/docs/api/carvekit/ml/arch/fba_matting/resnet_bn.html @@ -0,0 +1,394 @@ + + + + + + +carvekit.ml.arch.fba_matting.resnet_bn API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.fba_matting.resnet_bn

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch.nn as nn
+import math
+from torch.nn import BatchNorm2d
+
+__all__ = ["ResNet"]
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    "3x3 convolution with padding"
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn2 = BatchNorm2d(planes, momentum=0.01)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1000):
+        self.inplanes = 128
+        super(ResNet, self).__init__()
+        self.conv1 = conv3x3(3, 64, stride=2)
+        self.bn1 = BatchNorm2d(64)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(64, 64)
+        self.bn2 = BatchNorm2d(64)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3 = conv3x3(64, 128)
+        self.bn3 = BatchNorm2d(128)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, return_indices=True
+        )
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AvgPool2d(7, stride=1)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2.0 / n))
+            elif isinstance(m, BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        x, indices = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+
+def l_resnet50():
+    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3])
+    return model
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ResNet +(block, layers, num_classes=1000) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing them to be nested in +a tree structure. You can assign the submodules as regular attributes:

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call to(), etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1000):
+        self.inplanes = 128
+        super(ResNet, self).__init__()
+        self.conv1 = conv3x3(3, 64, stride=2)
+        self.bn1 = BatchNorm2d(64)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(64, 64)
+        self.bn2 = BatchNorm2d(64)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3 = conv3x3(64, 128)
+        self.bn3 = BatchNorm2d(128)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, return_indices=True
+        )
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AvgPool2d(7, stride=1)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2.0 / n))
+            elif isinstance(m, BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        x, indices = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.relu1(self.bn1(self.conv1(x)))
+    x = self.relu2(self.bn2(self.conv2(x)))
+    x = self.relu3(self.bn3(self.conv3(x)))
+    x, indices = self.maxpool(x)
+
+    x = self.layer1(x)
+    x = self.layer2(x)
+    x = self.layer3(x)
+    x = self.layer4(x)
+
+    x = self.avgpool(x)
+    x = x.view(x.size(0), -1)
+    x = self.fc(x)
+    return x
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/fba_matting/transforms.html b/docs/api/carvekit/ml/arch/fba_matting/transforms.html new file mode 100644 index 0000000..55fa4d2 --- /dev/null +++ b/docs/api/carvekit/ml/arch/fba_matting/transforms.html @@ -0,0 +1,180 @@ + + + + + + +carvekit.ml.arch.fba_matting.transforms API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.fba_matting.transforms

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/MarcoForte/FBA_Matting +License: MIT License

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import cv2
+import numpy as np
+
+group_norm_std = [0.229, 0.224, 0.225]
+group_norm_mean = [0.485, 0.456, 0.406]
+
+
+def dt(a):
+    return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)
+
+
+def trimap_transform(trimap):
+    h, w = trimap.shape[0], trimap.shape[1]
+
+    clicks = np.zeros((h, w, 6))
+    for k in range(2):
+        if np.count_nonzero(trimap[:, :, k]) > 0:
+            dt_mask = -dt(1 - trimap[:, :, k]) ** 2
+            L = 320
+            clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2)))
+            clicks[:, :, 3 * k + 1] = np.exp(dt_mask / (2 * ((0.08 * L) ** 2)))
+            clicks[:, :, 3 * k + 2] = np.exp(dt_mask / (2 * ((0.16 * L) ** 2)))
+
+    return clicks
+
+
+def groupnorm_normalise_image(img, format="nhwc"):
+    """
+    Accept rgb in range 0,1
+    """
+    if format == "nhwc":
+        for i in range(3):
+            img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]
+    else:
+        for i in range(3):
+            img[..., i, :, :] = (
+                img[..., i, :, :] - group_norm_mean[i]
+            ) / group_norm_std[i]
+
+    return img
+
+
+
+
+
+
+
+

Functions

+
+
+def dt(a) +
+
+
+
+ +Expand source code + +
def dt(a):
+    return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)
+
+
+
+def groupnorm_normalise_image(img, format='nhwc') +
+
+

Accept rgb in range 0,1

+
+ +Expand source code + +
def groupnorm_normalise_image(img, format="nhwc"):
+    """
+    Accept rgb in range 0,1
+    """
+    if format == "nhwc":
+        for i in range(3):
+            img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]
+    else:
+        for i in range(3):
+            img[..., i, :, :] = (
+                img[..., i, :, :] - group_norm_mean[i]
+            ) / group_norm_std[i]
+
+    return img
+
+
+
+def trimap_transform(trimap) +
+
+
+
+ +Expand source code + +
def trimap_transform(trimap):
+    h, w = trimap.shape[0], trimap.shape[1]
+
+    clicks = np.zeros((h, w, 6))
+    for k in range(2):
+        if np.count_nonzero(trimap[:, :, k]) > 0:
+            dt_mask = -dt(1 - trimap[:, :, k]) ** 2
+            L = 320
+            clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2)))
+            clicks[:, :, 3 * k + 1] = np.exp(dt_mask / (2 * ((0.08 * L) ** 2)))
+            clicks[:, :, 3 * k + 2] = np.exp(dt_mask / (2 * ((0.16 * L) ** 2)))
+
+    return clicks
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/index.html b/docs/api/carvekit/ml/arch/index.html new file mode 100644 index 0000000..a336cab --- /dev/null +++ b/docs/api/carvekit/ml/arch/index.html @@ -0,0 +1,90 @@ + + + + + + +carvekit.ml.arch API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.basnet
+
+
+
+
carvekit.ml.arch.cascadepsp
+
+
+
+
carvekit.ml.arch.fba_matting
+
+
+
+
carvekit.ml.arch.tracerb7
+
+
+
+
carvekit.ml.arch.u2net
+
+
+
+
carvekit.ml.arch.yolov4
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/tracerb7/att_modules.html b/docs/api/carvekit/ml/arch/tracerb7/att_modules.html new file mode 100644 index 0000000..38c59fd --- /dev/null +++ b/docs/api/carvekit/ml/arch/tracerb7/att_modules.html @@ -0,0 +1,1126 @@ + + + + + + +carvekit.ml.arch.tracerb7.att_modules API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.tracerb7.att_modules

+
+
+

Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/Karel911/TRACER
+Author: Min Seok Lee and Wooseok Shin
+License: Apache License 2.0
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWConv, DWSConv
+
+
+class RFB_Block(nn.Module):
+    def __init__(self, in_channel, out_channel):
+        super(RFB_Block, self).__init__()
+        self.relu = nn.ReLU(True)
+        self.branch0 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+        )
+        self.branch1 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3),
+        )
+        self.branch2 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5),
+        )
+        self.branch3 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7),
+        )
+        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
+        self.conv_res = BasicConv2d(in_channel, out_channel, 1)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+        x_cat = torch.cat((x0, x1, x2, x3), 1)
+        x_cat = self.conv_cat(x_cat)
+
+        x = self.relu(x_cat + self.conv_res(x))
+        return x
+
+
+class GlobalAvgPool(nn.Module):
+    def __init__(self, flatten=False):
+        super(GlobalAvgPool, self).__init__()
+        self.flatten = flatten
+
+    def forward(self, x):
+        if self.flatten:
+            in_size = x.size()
+            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+        else:
+            return (
+                x.view(x.size(0), x.size(1), -1)
+                .mean(-1)
+                .view(x.size(0), x.size(1), 1, 1)
+            )
+
+
+class UnionAttentionModule(nn.Module):
+    def __init__(self, n_channels, only_channel_tracing=False):
+        super(UnionAttentionModule, self).__init__()
+        self.GAP = GlobalAvgPool()
+        self.confidence_ratio = 0.1
+        self.bn = nn.BatchNorm2d(n_channels)
+        self.norm = nn.Sequential(
+            nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio)
+        )
+        self.channel_q = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.channel_k = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.channel_v = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+
+        self.fc = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+
+        if only_channel_tracing is False:
+            self.spatial_q = nn.Conv2d(
+                in_channels=n_channels,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+            self.spatial_k = nn.Conv2d(
+                in_channels=n_channels,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+            self.spatial_v = nn.Conv2d(
+                in_channels=n_channels,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+        self.sigmoid = nn.Sigmoid()
+
+    def masking(self, x, mask):
+        mask = mask.squeeze(3).squeeze(2)
+        threshold = torch.quantile(
+            mask.float(), self.confidence_ratio, dim=-1, keepdim=True
+        )
+        mask[mask <= threshold] = 0.0
+        mask = mask.unsqueeze(2).unsqueeze(3)
+        mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
+        masked_x = x * mask
+
+        return masked_x
+
+    def Channel_Tracer(self, x):
+        avg_pool = self.GAP(x)
+        x_norm = self.norm(avg_pool)
+
+        q = self.channel_q(x_norm).squeeze(-1)
+        k = self.channel_k(x_norm).squeeze(-1)
+        v = self.channel_v(x_norm).squeeze(-1)
+
+        # softmax(Q*K^T)
+        QK_T = torch.matmul(q, k.transpose(1, 2))
+        alpha = F.softmax(QK_T, dim=-1)
+
+        # a*v
+        att = torch.matmul(alpha, v).unsqueeze(-1)
+        att = self.fc(att)
+        att = self.sigmoid(att)
+
+        output = (x * att) + x
+        alpha_mask = att.clone()
+
+        return output, alpha_mask
+
+    def forward(self, x):
+        X_c, alpha_mask = self.Channel_Tracer(x)
+        X_c = self.bn(X_c)
+        x_drop = self.masking(X_c, alpha_mask)
+
+        q = self.spatial_q(x_drop).squeeze(1)
+        k = self.spatial_k(x_drop).squeeze(1)
+        v = self.spatial_v(x_drop).squeeze(1)
+
+        # softmax(Q*K^T)
+        QK_T = torch.matmul(q, k.transpose(1, 2))
+        alpha = F.softmax(QK_T, dim=-1)
+
+        output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1)
+
+        return output
+
+
+class aggregation(nn.Module):
+    def __init__(self, channel):
+        super(aggregation, self).__init__()
+        self.relu = nn.ReLU(True)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv_upsample1 = BasicConv2d(channel[2], channel[1], 3, padding=1)
+        self.conv_upsample2 = BasicConv2d(channel[2], channel[0], 3, padding=1)
+        self.conv_upsample3 = BasicConv2d(channel[1], channel[0], 3, padding=1)
+        self.conv_upsample4 = BasicConv2d(channel[2], channel[2], 3, padding=1)
+        self.conv_upsample5 = BasicConv2d(
+            channel[2] + channel[1], channel[2] + channel[1], 3, padding=1
+        )
+
+        self.conv_concat2 = BasicConv2d(
+            (channel[2] + channel[1]), (channel[2] + channel[1]), 3, padding=1
+        )
+        self.conv_concat3 = BasicConv2d(
+            (channel[0] + channel[1] + channel[2]),
+            (channel[0] + channel[1] + channel[2]),
+            3,
+            padding=1,
+        )
+
+        self.UAM = UnionAttentionModule(channel[0] + channel[1] + channel[2])
+
+    def forward(self, e4, e3, e2):
+        e4_1 = e4
+        e3_1 = self.conv_upsample1(self.upsample(e4)) * e3
+        e2_1 = (
+            self.conv_upsample2(self.upsample(self.upsample(e4)))
+            * self.conv_upsample3(self.upsample(e3))
+            * e2
+        )
+
+        e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1)
+        e3_2 = self.conv_concat2(e3_2)
+
+        e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1)
+        x = self.conv_concat3(e2_2)
+
+        output = self.UAM(x)
+
+        return output
+
+
+class ObjectAttention(nn.Module):
+    def __init__(self, channel, kernel_size):
+        super(ObjectAttention, self).__init__()
+        self.channel = channel
+        self.DWSConv = DWSConv(
+            channel, channel // 2, kernel=kernel_size, padding=1, kernels_per_layer=1
+        )
+        self.DWConv1 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=1, padding=0, dilation=1),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.DWConv2 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=3, padding=1, dilation=1),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.DWConv3 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=3, padding=3, dilation=3),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.DWConv4 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=3, padding=5, dilation=5),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.conv1 = BasicConv2d(channel // 2, 1, 1)
+
+    def forward(self, decoder_map, encoder_map):
+        """
+        Args:
+            decoder_map: decoder representation (B, 1, H, W).
+            encoder_map: encoder block output (B, C, H, W).
+        Returns:
+            decoder representation: (B, 1, H, W)
+        """
+        mask_bg = -1 * torch.sigmoid(decoder_map) + 1  # Sigmoid & Reverse
+        mask_ob = torch.sigmoid(decoder_map)  # object attention
+        x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)
+
+        edge = mask_bg.clone()
+        edge[edge > 0.93] = 0
+        x = x + (edge * encoder_map)
+
+        x = self.DWSConv(x)
+        skip = x.clone()
+        x = (
+            torch.cat(
+                [self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)],
+                dim=1,
+            )
+            + skip
+        )
+        x = torch.relu(self.conv1(x))
+
+        return x + decoder_map
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class GlobalAvgPool +(flatten=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class GlobalAvgPool(nn.Module):
+    def __init__(self, flatten=False):
+        super(GlobalAvgPool, self).__init__()
+        self.flatten = flatten
+
+    def forward(self, x):
+        if self.flatten:
+            in_size = x.size()
+            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+        else:
+            return (
+                x.view(x.size(0), x.size(1), -1)
+                .mean(-1)
+                .view(x.size(0), x.size(1), 1, 1)
+            )
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑>Β Callable[...,Β Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    if self.flatten:
+        in_size = x.size()
+        return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+    else:
+        return (
+            x.view(x.size(0), x.size(1), -1)
+            .mean(-1)
+            .view(x.size(0), x.size(1), 1, 1)
+        )
+
+
+
+
+
+class ObjectAttention +(channel, kernel_size) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ObjectAttention(nn.Module):
+    def __init__(self, channel, kernel_size):
+        super(ObjectAttention, self).__init__()
+        self.channel = channel
+        self.DWSConv = DWSConv(
+            channel, channel // 2, kernel=kernel_size, padding=1, kernels_per_layer=1
+        )
+        self.DWConv1 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=1, padding=0, dilation=1),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.DWConv2 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=3, padding=1, dilation=1),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.DWConv3 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=3, padding=3, dilation=3),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.DWConv4 = nn.Sequential(
+            DWConv(channel // 2, channel // 2, kernel=3, padding=5, dilation=5),
+            BasicConv2d(channel // 2, channel // 8, 1),
+        )
+        self.conv1 = BasicConv2d(channel // 2, 1, 1)
+
+    def forward(self, decoder_map, encoder_map):
+        """
+        Args:
+            decoder_map: decoder representation (B, 1, H, W).
+            encoder_map: encoder block output (B, C, H, W).
+        Returns:
+            decoder representation: (B, 1, H, W)
+        """
+        mask_bg = -1 * torch.sigmoid(decoder_map) + 1  # Sigmoid & Reverse
+        mask_ob = torch.sigmoid(decoder_map)  # object attention
+        x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)
+
+        edge = mask_bg.clone()
+        edge[edge > 0.93] = 0
+        x = x + (edge * encoder_map)
+
+        x = self.DWSConv(x)
+        skip = x.clone()
+        x = (
+            torch.cat(
+                [self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)],
+                dim=1,
+            )
+            + skip
+        )
+        x = torch.relu(self.conv1(x))
+
+        return x + decoder_map
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, decoder_map, encoder_map) ‑>Β Callable[...,Β Any] +
+
+

Args

+
+
decoder_map
+
decoder representation (B, 1, H, W).
+
encoder_map
+
encoder block output (B, C, H, W).
+
+

Returns

+
+
decoder representation
+
(B, 1, H, W)
+
+
+ +Expand source code + +
def forward(self, decoder_map, encoder_map):
+    """
+    Args:
+        decoder_map: decoder representation (B, 1, H, W).
+        encoder_map: encoder block output (B, C, H, W).
+    Returns:
+        decoder representation: (B, 1, H, W)
+    """
+    mask_bg = -1 * torch.sigmoid(decoder_map) + 1  # Sigmoid & Reverse
+    mask_ob = torch.sigmoid(decoder_map)  # object attention
+    x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)
+
+    edge = mask_bg.clone()
+    edge[edge > 0.93] = 0
+    x = x + (edge * encoder_map)
+
+    x = self.DWSConv(x)
+    skip = x.clone()
+    x = (
+        torch.cat(
+            [self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)],
+            dim=1,
+        )
+        + skip
+    )
+    x = torch.relu(self.conv1(x))
+
+    return x + decoder_map
+
+
+
+
+
+class RFB_Block +(in_channel, out_channel) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class RFB_Block(nn.Module):
+    def __init__(self, in_channel, out_channel):
+        super(RFB_Block, self).__init__()
+        self.relu = nn.ReLU(True)
+        self.branch0 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+        )
+        self.branch1 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3),
+        )
+        self.branch2 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5),
+        )
+        self.branch3 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7),
+        )
+        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
+        self.conv_res = BasicConv2d(in_channel, out_channel, 1)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+        x_cat = torch.cat((x0, x1, x2, x3), 1)
+        x_cat = self.conv_cat(x_cat)
+
+        x = self.relu(x_cat + self.conv_res(x))
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑>Β Callable[...,Β Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x0 = self.branch0(x)
+    x1 = self.branch1(x)
+    x2 = self.branch2(x)
+    x3 = self.branch3(x)
+    x_cat = torch.cat((x0, x1, x2, x3), 1)
+    x_cat = self.conv_cat(x_cat)
+
+    x = self.relu(x_cat + self.conv_res(x))
+    return x
+
+
+
+
+
+class UnionAttentionModule +(n_channels, only_channel_tracing=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class UnionAttentionModule(nn.Module):
+    def __init__(self, n_channels, only_channel_tracing=False):
+        super(UnionAttentionModule, self).__init__()
+        self.GAP = GlobalAvgPool()
+        self.confidence_ratio = 0.1
+        self.bn = nn.BatchNorm2d(n_channels)
+        self.norm = nn.Sequential(
+            nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio)
+        )
+        self.channel_q = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.channel_k = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.channel_v = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+
+        self.fc = nn.Conv2d(
+            in_channels=n_channels,
+            out_channels=n_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+
+        if only_channel_tracing is False:
+            self.spatial_q = nn.Conv2d(
+                in_channels=n_channels,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+            self.spatial_k = nn.Conv2d(
+                in_channels=n_channels,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+            self.spatial_v = nn.Conv2d(
+                in_channels=n_channels,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+        self.sigmoid = nn.Sigmoid()
+
+    def masking(self, x, mask):
+        mask = mask.squeeze(3).squeeze(2)
+        threshold = torch.quantile(
+            mask.float(), self.confidence_ratio, dim=-1, keepdim=True
+        )
+        mask[mask <= threshold] = 0.0
+        mask = mask.unsqueeze(2).unsqueeze(3)
+        mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
+        masked_x = x * mask
+
+        return masked_x
+
+    def Channel_Tracer(self, x):
+        avg_pool = self.GAP(x)
+        x_norm = self.norm(avg_pool)
+
+        q = self.channel_q(x_norm).squeeze(-1)
+        k = self.channel_k(x_norm).squeeze(-1)
+        v = self.channel_v(x_norm).squeeze(-1)
+
+        # softmax(Q*K^T)
+        QK_T = torch.matmul(q, k.transpose(1, 2))
+        alpha = F.softmax(QK_T, dim=-1)
+
+        # a*v
+        att = torch.matmul(alpha, v).unsqueeze(-1)
+        att = self.fc(att)
+        att = self.sigmoid(att)
+
+        output = (x * att) + x
+        alpha_mask = att.clone()
+
+        return output, alpha_mask
+
+    def forward(self, x):
+        X_c, alpha_mask = self.Channel_Tracer(x)
+        X_c = self.bn(X_c)
+        x_drop = self.masking(X_c, alpha_mask)
+
+        q = self.spatial_q(x_drop).squeeze(1)
+        k = self.spatial_k(x_drop).squeeze(1)
+        v = self.spatial_v(x_drop).squeeze(1)
+
+        # softmax(Q*K^T)
+        QK_T = torch.matmul(q, k.transpose(1, 2))
+        alpha = F.softmax(QK_T, dim=-1)
+
+        output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1)
+
+        return output
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def Channel_Tracer(self, x) +
+
+
+
+ +Expand source code + +
def Channel_Tracer(self, x):
+    avg_pool = self.GAP(x)
+    x_norm = self.norm(avg_pool)
+
+    q = self.channel_q(x_norm).squeeze(-1)
+    k = self.channel_k(x_norm).squeeze(-1)
+    v = self.channel_v(x_norm).squeeze(-1)
+
+    # softmax(Q*K^T)
+    QK_T = torch.matmul(q, k.transpose(1, 2))
+    alpha = F.softmax(QK_T, dim=-1)
+
+    # a*v
+    att = torch.matmul(alpha, v).unsqueeze(-1)
+    att = self.fc(att)
+    att = self.sigmoid(att)
+
+    output = (x * att) + x
+    alpha_mask = att.clone()
+
+    return output, alpha_mask
+
+
+
+def forward(self, x) ‑>Β Callable[...,Β Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    X_c, alpha_mask = self.Channel_Tracer(x)
+    X_c = self.bn(X_c)
+    x_drop = self.masking(X_c, alpha_mask)
+
+    q = self.spatial_q(x_drop).squeeze(1)
+    k = self.spatial_k(x_drop).squeeze(1)
+    v = self.spatial_v(x_drop).squeeze(1)
+
+    # softmax(Q*K^T)
+    QK_T = torch.matmul(q, k.transpose(1, 2))
+    alpha = F.softmax(QK_T, dim=-1)
+
+    output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1)
+
+    return output
+
+
+
+def masking(self, x, mask) +
+
+
+
+ +Expand source code + +
def masking(self, x, mask):
+    mask = mask.squeeze(3).squeeze(2)
+    threshold = torch.quantile(
+        mask.float(), self.confidence_ratio, dim=-1, keepdim=True
+    )
+    mask[mask <= threshold] = 0.0
+    mask = mask.unsqueeze(2).unsqueeze(3)
+    mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
+    masked_x = x * mask
+
+    return masked_x
+
+
+
+
+
+class aggregation +(channel) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class aggregation(nn.Module):
+    def __init__(self, channel):
+        super(aggregation, self).__init__()
+        self.relu = nn.ReLU(True)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv_upsample1 = BasicConv2d(channel[2], channel[1], 3, padding=1)
+        self.conv_upsample2 = BasicConv2d(channel[2], channel[0], 3, padding=1)
+        self.conv_upsample3 = BasicConv2d(channel[1], channel[0], 3, padding=1)
+        self.conv_upsample4 = BasicConv2d(channel[2], channel[2], 3, padding=1)
+        self.conv_upsample5 = BasicConv2d(
+            channel[2] + channel[1], channel[2] + channel[1], 3, padding=1
+        )
+
+        self.conv_concat2 = BasicConv2d(
+            (channel[2] + channel[1]), (channel[2] + channel[1]), 3, padding=1
+        )
+        self.conv_concat3 = BasicConv2d(
+            (channel[0] + channel[1] + channel[2]),
+            (channel[0] + channel[1] + channel[2]),
+            3,
+            padding=1,
+        )
+
+        self.UAM = UnionAttentionModule(channel[0] + channel[1] + channel[2])
+
+    def forward(self, e4, e3, e2):
+        e4_1 = e4
+        e3_1 = self.conv_upsample1(self.upsample(e4)) * e3
+        e2_1 = (
+            self.conv_upsample2(self.upsample(self.upsample(e4)))
+            * self.conv_upsample3(self.upsample(e3))
+            * e2
+        )
+
+        e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1)
+        e3_2 = self.conv_concat2(e3_2)
+
+        e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1)
+        x = self.conv_concat3(e2_2)
+
+        output = self.UAM(x)
+
+        return output
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, e4, e3, e2) ‑>Β Callable[...,Β Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, e4, e3, e2):
+    e4_1 = e4
+    e3_1 = self.conv_upsample1(self.upsample(e4)) * e3
+    e2_1 = (
+        self.conv_upsample2(self.upsample(self.upsample(e4)))
+        * self.conv_upsample3(self.upsample(e3))
+        * e2
+    )
+
+    e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1)
+    e3_2 = self.conv_concat2(e3_2)
+
+    e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1)
+    x = self.conv_concat3(e2_2)
+
+    output = self.UAM(x)
+
+    return output
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/tracerb7/conv_modules.html b/docs/api/carvekit/ml/arch/tracerb7/conv_modules.html new file mode 100644 index 0000000..08302f0 --- /dev/null +++ b/docs/api/carvekit/ml/arch/tracerb7/conv_modules.html @@ -0,0 +1,465 @@ + + + + + + +carvekit.ml.arch.tracerb7.conv_modules API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.tracerb7.conv_modules

+
+
+

Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/Karel911/TRACER
+Author: Min Seok Lee and Wooseok Shin
+License: Apache License 2.0
+"""
+import torch.nn as nn
+
+
+class BasicConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=(1, 1),
+        padding=(0, 0),
+        dilation=(1, 1),
+    ):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channel,
+            out_channel,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(out_channel)
+        self.selu = nn.SELU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.selu(x)
+
+        return x
+
+
+class DWConv(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel, dilation, padding):
+        super(DWConv, self).__init__()
+        self.out_channel = out_channel
+        self.DWConv = nn.Conv2d(
+            in_channel,
+            out_channel,
+            kernel_size=kernel,
+            padding=padding,
+            groups=in_channel,
+            dilation=dilation,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(out_channel)
+        self.selu = nn.SELU()
+
+    def forward(self, x):
+        x = self.DWConv(x)
+        out = self.selu(self.bn(x))
+
+        return out
+
+
+class DWSConv(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
+        super(DWSConv, self).__init__()
+        self.out_channel = out_channel
+        self.DWConv = nn.Conv2d(
+            in_channel,
+            in_channel * kernels_per_layer,
+            kernel_size=kernel,
+            padding=padding,
+            groups=in_channel,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer)
+        self.selu = nn.SELU()
+        self.PWConv = nn.Conv2d(
+            in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(out_channel)
+
+    def forward(self, x):
+        x = self.DWConv(x)
+        x = self.selu(self.bn(x))
+        out = self.PWConv(x)
+        out = self.selu(self.bn2(out))
+
+        return out
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BasicConv2d +(in_channel, out_channel, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BasicConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=(1, 1),
+        padding=(0, 0),
+        dilation=(1, 1),
+    ):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channel,
+            out_channel,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(out_channel)
+        self.selu = nn.SELU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.selu(x)
+
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) -> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.conv(x)
+    x = self.bn(x)
+    x = self.selu(x)
+
+    return x
+
+
+
+
+
+class DWConv +(in_channel, out_channel, kernel, dilation, padding) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DWConv(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel, dilation, padding):
+        super(DWConv, self).__init__()
+        self.out_channel = out_channel
+        self.DWConv = nn.Conv2d(
+            in_channel,
+            out_channel,
+            kernel_size=kernel,
+            padding=padding,
+            groups=in_channel,
+            dilation=dilation,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(out_channel)
+        self.selu = nn.SELU()
+
+    def forward(self, x):
+        x = self.DWConv(x)
+        out = self.selu(self.bn(x))
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) -> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.DWConv(x)
+    out = self.selu(self.bn(x))
+
+    return out
+
+
+
+
+
+class DWSConv +(in_channel, out_channel, kernel, padding, kernels_per_layer) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DWSConv(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
+        super(DWSConv, self).__init__()
+        self.out_channel = out_channel
+        self.DWConv = nn.Conv2d(
+            in_channel,
+            in_channel * kernels_per_layer,
+            kernel_size=kernel,
+            padding=padding,
+            groups=in_channel,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer)
+        self.selu = nn.SELU()
+        self.PWConv = nn.Conv2d(
+            in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(out_channel)
+
+    def forward(self, x):
+        x = self.DWConv(x)
+        x = self.selu(self.bn(x))
+        out = self.PWConv(x)
+        out = self.selu(self.bn2(out))
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) -> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.DWConv(x)
+    x = self.selu(self.bn(x))
+    out = self.PWConv(x)
+    out = self.selu(self.bn2(out))
+
+    return out
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/tracerb7/effi_utils.html b/docs/api/carvekit/ml/arch/tracerb7/effi_utils.html new file mode 100644 index 0000000..a94c44a --- /dev/null +++ b/docs/api/carvekit/ml/arch/tracerb7/effi_utils.html @@ -0,0 +1,2066 @@ + + + + + + +carvekit.ml.arch.tracerb7.effi_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.tracerb7.effi_utils

+
+
+

Original author: lukemelas (github username) +Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +With adjustments and added comments by workingcoder (github username). +License: Apache License 2.0 +Reimplemented: Min Seok Lee and Wooseok Shin

+
+ +Expand source code + +
"""
+Original author: lukemelas (github username)
+Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+With adjustments and added comments by workingcoder (github username).
+License: Apache License 2.0
+Reimplemented: Min Seok Lee and Wooseok Shin
+"""
+
+import collections
+import re
+from functools import partial
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+# Parameters for the entire model (stem, all blocks, and head)
+GlobalParams = collections.namedtuple(
+    "GlobalParams",
+    [
+        "width_coefficient",
+        "depth_coefficient",
+        "image_size",
+        "dropout_rate",
+        "num_classes",
+        "batch_norm_momentum",
+        "batch_norm_epsilon",
+        "drop_connect_rate",
+        "depth_divisor",
+        "min_depth",
+        "include_top",
+    ],
+)
+
+# Parameters for an individual model block
+BlockArgs = collections.namedtuple(
+    "BlockArgs",
+    [
+        "num_repeat",
+        "kernel_size",
+        "stride",
+        "expand_ratio",
+        "input_filters",
+        "output_filters",
+        "se_ratio",
+        "id_skip",
+    ],
+)
+
+# Set GlobalParams and BlockArgs's defaults
+GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
+BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
+
+
+# An ordinary implementation of Swish function
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+# A memory-efficient implementation of Swish function
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+
+def round_filters(filters, global_params):
+    """Calculate and round number of filters based on width multiplier.
+       Use width_coefficient, depth_divisor and min_depth of global_params.
+
+    Args:
+        filters (int): Filters number to be calculated.
+        global_params (namedtuple): Global params of the model.
+
+    Returns:
+        new_filters: New filters number after calculating.
+    """
+    multiplier = global_params.width_coefficient
+    if not multiplier:
+        return filters
+    divisor = global_params.depth_divisor
+    min_depth = global_params.min_depth
+    filters *= multiplier
+    min_depth = min_depth or divisor  # pay attention to this line when using min_depth
+    # follow the formula transferred from official TensorFlow implementation
+    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
+        new_filters += divisor
+    return int(new_filters)
+
+
+def round_repeats(repeats, global_params):
+    """Calculate module's repeat number of a block based on depth multiplier.
+       Use depth_coefficient of global_params.
+
+    Args:
+        repeats (int): num_repeat to be calculated.
+        global_params (namedtuple): Global params of the model.
+
+    Returns:
+        new repeat: New repeat number after calculating.
+    """
+    multiplier = global_params.depth_coefficient
+    if not multiplier:
+        return repeats
+    # follow the formula transferred from official TensorFlow implementation
+    return int(math.ceil(multiplier * repeats))
+
+
+def drop_connect(inputs, p, training):
+    """Drop connect.
+
+    Args:
+        inputs (tensor: BCWH): Input of this structure.
+        p (float: 0.0~1.0): Probability of drop connection.
+        training (bool): The running mode.
+
+    Returns:
+        output: Output after drop connection.
+    """
+    assert 0 <= p <= 1, "p must be in range of [0,1]"
+
+    if not training:
+        return inputs
+
+    batch_size = inputs.shape[0]
+    keep_prob = 1 - p
+
+    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
+    random_tensor = keep_prob
+    random_tensor += torch.rand(
+        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
+    )
+    binary_tensor = torch.floor(random_tensor)
+
+    output = inputs / keep_prob * binary_tensor
+    return output
+
+
+def get_width_and_height_from_size(x):
+    """Obtain height and width from x.
+
+    Args:
+        x (int, tuple or list): Data size.
+
+    Returns:
+        size: A tuple or list (H,W).
+    """
+    if isinstance(x, int):
+        return x, x
+    if isinstance(x, list) or isinstance(x, tuple):
+        return x
+    else:
+        raise TypeError()
+
+
+def calculate_output_image_size(input_image_size, stride):
+    """Calculates the output image size when using Conv2dSamePadding with a stride.
+       Necessary for static padding. Thanks to mannatsingh for pointing this out.
+
+    Args:
+        input_image_size (int, tuple or list): Size of input image.
+        stride (int, tuple or list): Conv2d operation's stride.
+
+    Returns:
+        output_image_size: A list [H,W].
+    """
+    if input_image_size is None:
+        return None
+    image_height, image_width = get_width_and_height_from_size(input_image_size)
+    stride = stride if isinstance(stride, int) else stride[0]
+    image_height = int(math.ceil(image_height / stride))
+    image_width = int(math.ceil(image_width / stride))
+    return [image_height, image_width]
+
+
+# Note:
+# The following 'SamePadding' helpers make the output size equal to ceil(input size / stride).
+# The output size matches the input size only when the stride equals 1.
+# Do not be misled by the word 'Same' in their names!
+
+
+def get_same_padding_conv2d(image_size=None):
+    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+       Static padding is necessary for ONNX exporting of models.
+
+    Args:
+        image_size (int or tuple): Size of the image.
+
+    Returns:
+        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
+    """
+    if image_size is None:
+        return Conv2dDynamicSamePadding
+    else:
+        return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+    """2D Convolutions like TensorFlow, for a dynamic image size.
+    The padding is operated in forward function by calculating dynamically.
+    """
+
+    # Tips for 'SAME' mode padding.
+    #     Given the following:
+    #         i: width or height
+    #         s: stride
+    #         k: kernel size
+    #         d: dilation
+    #         p: padding
+    #     Output after Conv2d:
+    #         o = floor((i+p-((k-1)*d+1))/s+1)
+    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
+    # => p = (i-1)*s+((k-1)*d+1)-i
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
+        )
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(
+            iw / sw
+        )  # change the output size according to stride ! ! !
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(
+                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+            )
+        return F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+
+class Conv2dStaticSamePadding(nn.Conv2d):
+    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
+    The padding module is computed in the constructor, then applied in forward.
+    """
+
+    # With the same calculation as Conv2dDynamicSamePadding
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        image_size=None,
+        **kwargs
+    ):
+        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d(
+                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+            )
+        else:
+            self.static_padding = nn.Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+        return x
+
+
+def get_same_padding_maxPool2d(image_size=None):
+    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+       Static padding is necessary for ONNX exporting of models.
+
+    Args:
+        image_size (int or tuple): Size of the image.
+
+    Returns:
+        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
+    """
+    if image_size is None:
+        return MaxPool2dDynamicSamePadding
+    else:
+        return partial(MaxPool2dStaticSamePadding, image_size=image_size)
+
+
+class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
+    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
+    The padding is operated in forward function by calculating dynamically.
+    """
+
+    def __init__(
+        self,
+        kernel_size,
+        stride,
+        padding=0,
+        dilation=1,
+        return_indices=False,
+        ceil_mode=False,
+    ):
+        super().__init__(
+            kernel_size, stride, padding, dilation, return_indices, ceil_mode
+        )
+        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+        self.kernel_size = (
+            [self.kernel_size] * 2
+            if isinstance(self.kernel_size, int)
+            else self.kernel_size
+        )
+        self.dilation = (
+            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+        )
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.kernel_size
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(
+                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+            )
+        return F.max_pool2d(
+            x,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.ceil_mode,
+            self.return_indices,
+        )
+
+
+class MaxPool2dStaticSamePadding(nn.MaxPool2d):
+    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
+    The padding module is computed in the constructor, then applied in forward.
+    """
+
+    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
+        super().__init__(kernel_size, stride, **kwargs)
+        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+        self.kernel_size = (
+            [self.kernel_size] * 2
+            if isinstance(self.kernel_size, int)
+            else self.kernel_size
+        )
+        self.dilation = (
+            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+        )
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+        kh, kw = self.kernel_size
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d(
+                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+            )
+        else:
+            self.static_padding = nn.Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.max_pool2d(
+            x,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.ceil_mode,
+            self.return_indices,
+        )
+        return x
+
+
+class BlockDecoder(object):
+    """Block Decoder for readability,
+    straight from the official TensorFlow repository.
+    """
+
+    @staticmethod
+    def _decode_block_string(block_string):
+        """Get a block through a string notation of arguments.
+
+        Args:
+            block_string (str): A string notation of arguments.
+                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
+
+        Returns:
+            BlockArgs: The namedtuple defined at the top of this file.
+        """
+        assert isinstance(block_string, str)
+
+        ops = block_string.split("_")
+        options = {}
+        for op in ops:
+            splits = re.split(r"(\d.*)", op)
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value
+
+        # Check stride
+        assert ("s" in options and len(options["s"]) == 1) or (
+            len(options["s"]) == 2 and options["s"][0] == options["s"][1]
+        )
+
+        return BlockArgs(
+            num_repeat=int(options["r"]),
+            kernel_size=int(options["k"]),
+            stride=[int(options["s"][0])],
+            expand_ratio=int(options["e"]),
+            input_filters=int(options["i"]),
+            output_filters=int(options["o"]),
+            se_ratio=float(options["se"]) if "se" in options else None,
+            id_skip=("noskip" not in block_string),
+        )
+
+    @staticmethod
+    def _encode_block_string(block):
+        """Encode a block to a string.
+
+        Args:
+            block (namedtuple): A BlockArgs type argument.
+
+        Returns:
+            block_string: A String form of BlockArgs.
+        """
+        args = [
+            "r%d" % block.num_repeat,
+            "k%d" % block.kernel_size,
+            "s%d%d" % (block.strides[0], block.strides[1]),
+            "e%s" % block.expand_ratio,
+            "i%d" % block.input_filters,
+            "o%d" % block.output_filters,
+        ]
+        if 0 < block.se_ratio <= 1:
+            args.append("se%s" % block.se_ratio)
+        if block.id_skip is False:
+            args.append("noskip")
+        return "_".join(args)
+
+    @staticmethod
+    def decode(string_list):
+        """Decode a list of string notations to specify blocks inside the network.
+
+        Args:
+            string_list (list[str]): A list of strings, each string is a notation of block.
+
+        Returns:
+            blocks_args: A list of BlockArgs namedtuples of block args.
+        """
+        assert isinstance(string_list, list)
+        blocks_args = []
+        for block_string in string_list:
+            blocks_args.append(BlockDecoder._decode_block_string(block_string))
+        return blocks_args
+
+    @staticmethod
+    def encode(blocks_args):
+        """Encode a list of BlockArgs to a list of strings.
+
+        Args:
+            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
+
+        Returns:
+            block_strings: A list of strings, each string is a notation of block.
+        """
+        block_strings = []
+        for block in blocks_args:
+            block_strings.append(BlockDecoder._encode_block_string(block))
+        return block_strings
+
+
+def create_block_args(
+    width_coefficient=None,
+    depth_coefficient=None,
+    image_size=None,
+    dropout_rate=0.2,
+    drop_connect_rate=0.2,
+    num_classes=1000,
+    include_top=True,
+):
+    """Create BlockArgs and GlobalParams for efficientnet model.
+
+    Args:
+        width_coefficient (float)
+        depth_coefficient (float)
+        image_size (int)
+        dropout_rate (float)
+        drop_connect_rate (float)
+        num_classes (int)
+
+        Meaning as the name suggests.
+
+    Returns:
+        blocks_args, global_params.
+    """
+
+    # Blocks args for the whole model(efficientnet-b0 by default)
+    # It will be modified in the construction of EfficientNet Class according to model
+    blocks_args = [
+        "r1_k3_s11_e1_i32_o16_se0.25",
+        "r2_k3_s22_e6_i16_o24_se0.25",
+        "r2_k5_s22_e6_i24_o40_se0.25",
+        "r3_k3_s22_e6_i40_o80_se0.25",
+        "r3_k5_s11_e6_i80_o112_se0.25",
+        "r4_k5_s22_e6_i112_o192_se0.25",
+        "r1_k3_s11_e6_i192_o320_se0.25",
+    ]
+    blocks_args = BlockDecoder.decode(blocks_args)
+
+    global_params = GlobalParams(
+        width_coefficient=width_coefficient,
+        depth_coefficient=depth_coefficient,
+        image_size=image_size,
+        dropout_rate=dropout_rate,
+        num_classes=num_classes,
+        batch_norm_momentum=0.99,
+        batch_norm_epsilon=1e-3,
+        drop_connect_rate=drop_connect_rate,
+        depth_divisor=8,
+        min_depth=None,
+        include_top=include_top,
+    )
+
+    return blocks_args, global_params
+
+
+
+
+
+
+
+

Functions

+
+
+def calculate_output_image_size(input_image_size, stride) +
+
+

Calculates the output image size when using Conv2dSamePadding with a stride. +Necessary for static padding. Thanks to mannatsingh for pointing this out.

+

Args

+
+
input_image_size : int, tuple or list
+
Size of input image.
+
stride : int, tuple or list
+
Conv2d operation's stride.
+
+

Returns

+
+
output_image_size
+
A list [H,W].
+
+
+ +Expand source code + +
def calculate_output_image_size(input_image_size, stride):
+    """Calculates the output image size when using Conv2dSamePadding with a stride.
+       Necessary for static padding. Thanks to mannatsingh for pointing this out.
+
+    Args:
+        input_image_size (int, tuple or list): Size of input image.
+        stride (int, tuple or list): Conv2d operation's stride.
+
+    Returns:
+        output_image_size: A list [H,W].
+    """
+    if input_image_size is None:
+        return None
+    image_height, image_width = get_width_and_height_from_size(input_image_size)
+    stride = stride if isinstance(stride, int) else stride[0]
+    image_height = int(math.ceil(image_height / stride))
+    image_width = int(math.ceil(image_width / stride))
+    return [image_height, image_width]
+
+
+
+def create_block_args(width_coefficient=None, depth_coefficient=None, image_size=None, dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True) +
+
+

Create BlockArgs and GlobalParams for efficientnet model.

+

Args

+

width_coefficient (float) +depth_coefficient (float) +image_size (int) +dropout_rate (float) +drop_connect_rate (float) +num_classes (int)

+

Meaning as the name suggests.

+

Returns

+

blocks_args, global_params.

+
+ +Expand source code + +
def create_block_args(
+    width_coefficient=None,
+    depth_coefficient=None,
+    image_size=None,
+    dropout_rate=0.2,
+    drop_connect_rate=0.2,
+    num_classes=1000,
+    include_top=True,
+):
+    """Create BlockArgs and GlobalParams for efficientnet model.
+
+    Args:
+        width_coefficient (float)
+        depth_coefficient (float)
+        image_size (int)
+        dropout_rate (float)
+        drop_connect_rate (float)
+        num_classes (int)
+
+        Meaning as the name suggests.
+
+    Returns:
+        blocks_args, global_params.
+    """
+
+    # Blocks args for the whole model(efficientnet-b0 by default)
+    # It will be modified in the construction of EfficientNet Class according to model
+    blocks_args = [
+        "r1_k3_s11_e1_i32_o16_se0.25",
+        "r2_k3_s22_e6_i16_o24_se0.25",
+        "r2_k5_s22_e6_i24_o40_se0.25",
+        "r3_k3_s22_e6_i40_o80_se0.25",
+        "r3_k5_s11_e6_i80_o112_se0.25",
+        "r4_k5_s22_e6_i112_o192_se0.25",
+        "r1_k3_s11_e6_i192_o320_se0.25",
+    ]
+    blocks_args = BlockDecoder.decode(blocks_args)
+
+    global_params = GlobalParams(
+        width_coefficient=width_coefficient,
+        depth_coefficient=depth_coefficient,
+        image_size=image_size,
+        dropout_rate=dropout_rate,
+        num_classes=num_classes,
+        batch_norm_momentum=0.99,
+        batch_norm_epsilon=1e-3,
+        drop_connect_rate=drop_connect_rate,
+        depth_divisor=8,
+        min_depth=None,
+        include_top=include_top,
+    )
+
+    return blocks_args, global_params
+
+
+
+def drop_connect(inputs, p, training) +
+
+

Drop connect.

+

Args

+
+
inputs (tensor: BCWH): Input of this structure.
+
p (float: 0.0~1.0): Probability of drop connection.
+
training : bool
+
The running mode.
+
+

Returns

+
+
output
+
Output after drop connection.
+
+
+ +Expand source code + +
def drop_connect(inputs, p, training):
+    """Drop connect.
+
+    Args:
+        inputs (tensor: BCWH): Input of this structure.
+        p (float: 0.0~1.0): Probability of drop connection.
+        training (bool): The running mode.
+
+    Returns:
+        output: Output after drop connection.
+    """
+    assert 0 <= p <= 1, "p must be in range of [0,1]"
+
+    if not training:
+        return inputs
+
+    batch_size = inputs.shape[0]
+    keep_prob = 1 - p
+
+    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
+    random_tensor = keep_prob
+    random_tensor += torch.rand(
+        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
+    )
+    binary_tensor = torch.floor(random_tensor)
+
+    output = inputs / keep_prob * binary_tensor
+    return output
+
+
+
+def get_same_padding_conv2d(image_size=None) +
+
+

Chooses static padding if you have specified an image size, and dynamic padding otherwise. +Static padding is necessary for ONNX exporting of models.

+

Args

+
+
image_size : int or tuple
+
Size of the image.
+
+

Returns

+

Conv2dDynamicSamePadding or Conv2dStaticSamePadding.

+
+ +Expand source code + +
def get_same_padding_conv2d(image_size=None):
+    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+       Static padding is necessary for ONNX exporting of models.
+
+    Args:
+        image_size (int or tuple): Size of the image.
+
+    Returns:
+        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
+    """
+    if image_size is None:
+        return Conv2dDynamicSamePadding
+    else:
+        return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+
+
+def get_same_padding_maxPool2d(image_size=None) +
+
+

Chooses static padding if you have specified an image size, and dynamic padding otherwise. +Static padding is necessary for ONNX exporting of models.

+

Args

+
+
image_size : int or tuple
+
Size of the image.
+
+

Returns

+

MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.

+
+ +Expand source code + +
def get_same_padding_maxPool2d(image_size=None):
+    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+       Static padding is necessary for ONNX exporting of models.
+
+    Args:
+        image_size (int or tuple): Size of the image.
+
+    Returns:
+        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
+    """
+    if image_size is None:
+        return MaxPool2dDynamicSamePadding
+    else:
+        return partial(MaxPool2dStaticSamePadding, image_size=image_size)
+
+
+
+def get_width_and_height_from_size(x) +
+
+

Obtain height and width from x.

+

Args

+
+
x : int, tuple or list
+
Data size.
+
+

Returns

+
+
size
+
A tuple or list (H,W).
+
+
+ +Expand source code + +
def get_width_and_height_from_size(x):
+    """Obtain height and width from x.
+
+    Args:
+        x (int, tuple or list): Data size.
+
+    Returns:
+        size: A tuple or list (H,W).
+    """
+    if isinstance(x, int):
+        return x, x
+    if isinstance(x, list) or isinstance(x, tuple):
+        return x
+    else:
+        raise TypeError()
+
+
+
+def round_filters(filters, global_params) +
+
+

Calculate and round number of filters based on width multiplier. +Use width_coefficient, depth_divisor and min_depth of global_params.

+

Args

+
+
filters : int
+
Filters number to be calculated.
+
global_params : namedtuple
+
Global params of the model.
+
+

Returns

+
+
new_filters
+
New filters number after calculating.
+
+
+ +Expand source code + +
def round_filters(filters, global_params):
+    """Calculate and round number of filters based on width multiplier.
+       Use width_coefficient, depth_divisor and min_depth of global_params.
+
+    Args:
+        filters (int): Filters number to be calculated.
+        global_params (namedtuple): Global params of the model.
+
+    Returns:
+        new_filters: New filters number after calculating.
+    """
+    multiplier = global_params.width_coefficient
+    if not multiplier:
+        return filters
+    divisor = global_params.depth_divisor
+    min_depth = global_params.min_depth
+    filters *= multiplier
+    min_depth = min_depth or divisor  # pay attention to this line when using min_depth
+    # follow the formula transferred from official TensorFlow implementation
+    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
+        new_filters += divisor
+    return int(new_filters)
+
+
+
+def round_repeats(repeats, global_params) +
+
+

Calculate module's repeat number of a block based on depth multiplier. +Use depth_coefficient of global_params.

+

Args

+
+
repeats : int
+
num_repeat to be calculated.
+
global_params : namedtuple
+
Global params of the model.
+
+

Returns

+
+
new repeat
+
New repeat number after calculating.
+
+
+ +Expand source code + +
def round_repeats(repeats, global_params):
+    """Calculate module's repeat number of a block based on depth multiplier.
+       Use depth_coefficient of global_params.
+
+    Args:
+        repeats (int): num_repeat to be calculated.
+        global_params (namedtuple): Global params of the model.
+
+    Returns:
+        new repeat: New repeat number after calculating.
+    """
+    multiplier = global_params.depth_coefficient
+    if not multiplier:
+        return repeats
+    # follow the formula transferred from official TensorFlow implementation
+    return int(math.ceil(multiplier * repeats))
+
+
+
+
+
+

Classes

+
+
+class BlockArgs +(num_repeat=None, kernel_size=None, stride=None, expand_ratio=None, input_filters=None, output_filters=None, se_ratio=None, id_skip=None) +
+
+

BlockArgs(num_repeat, kernel_size, stride, expand_ratio, input_filters, output_filters, se_ratio, id_skip)

+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var expand_ratio
+
+

Alias for field number 3

+
+
var id_skip
+
+

Alias for field number 7

+
+
var input_filters
+
+

Alias for field number 4

+
+
var kernel_size
+
+

Alias for field number 1

+
+
var num_repeat
+
+

Alias for field number 0

+
+
var output_filters
+
+

Alias for field number 5

+
+
var se_ratio
+
+

Alias for field number 6

+
+
var stride
+
+

Alias for field number 2

+
+
+
+
+class BlockDecoder +
+
+

Block Decoder for readability, +straight from the official TensorFlow repository.

+
+ +Expand source code + +
class BlockDecoder(object):
+    """Block Decoder for readability,
+    straight from the official TensorFlow repository.
+    """
+
+    @staticmethod
+    def _decode_block_string(block_string):
+        """Get a block through a string notation of arguments.
+
+        Args:
+            block_string (str): A string notation of arguments.
+                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
+
+        Returns:
+            BlockArgs: The namedtuple defined at the top of this file.
+        """
+        assert isinstance(block_string, str)
+
+        ops = block_string.split("_")
+        options = {}
+        for op in ops:
+            splits = re.split(r"(\d.*)", op)
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value
+
+        # Check stride
+        assert ("s" in options and len(options["s"]) == 1) or (
+            len(options["s"]) == 2 and options["s"][0] == options["s"][1]
+        )
+
+        return BlockArgs(
+            num_repeat=int(options["r"]),
+            kernel_size=int(options["k"]),
+            stride=[int(options["s"][0])],
+            expand_ratio=int(options["e"]),
+            input_filters=int(options["i"]),
+            output_filters=int(options["o"]),
+            se_ratio=float(options["se"]) if "se" in options else None,
+            id_skip=("noskip" not in block_string),
+        )
+
+    @staticmethod
+    def _encode_block_string(block):
+        """Encode a block to a string.
+
+        Args:
+            block (namedtuple): A BlockArgs type argument.
+
+        Returns:
+            block_string: A String form of BlockArgs.
+        """
+        args = [
+            "r%d" % block.num_repeat,
+            "k%d" % block.kernel_size,
+            "s%d%d" % (block.strides[0], block.strides[1]),
+            "e%s" % block.expand_ratio,
+            "i%d" % block.input_filters,
+            "o%d" % block.output_filters,
+        ]
+        if 0 < block.se_ratio <= 1:
+            args.append("se%s" % block.se_ratio)
+        if block.id_skip is False:
+            args.append("noskip")
+        return "_".join(args)
+
+    @staticmethod
+    def decode(string_list):
+        """Decode a list of string notations to specify blocks inside the network.
+
+        Args:
+            string_list (list[str]): A list of strings, each string is a notation of block.
+
+        Returns:
+            blocks_args: A list of BlockArgs namedtuples of block args.
+        """
+        assert isinstance(string_list, list)
+        blocks_args = []
+        for block_string in string_list:
+            blocks_args.append(BlockDecoder._decode_block_string(block_string))
+        return blocks_args
+
+    @staticmethod
+    def encode(blocks_args):
+        """Encode a list of BlockArgs to a list of strings.
+
+        Args:
+            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
+
+        Returns:
+            block_strings: A list of strings, each string is a notation of block.
+        """
+        block_strings = []
+        for block in blocks_args:
+            block_strings.append(BlockDecoder._encode_block_string(block))
+        return block_strings
+
+

Static methods

+
+
+def decode(string_list) +
+
+

Decode a list of string notations to specify blocks inside the network.

+

Args

+
+
string_list : list[str]
+
A list of strings, each string is a notation of block.
+
+

Returns

+
+
blocks_args
+
A list of BlockArgs namedtuples of block args.
+
+
+ +Expand source code + +
@staticmethod
+def decode(string_list):
+    """Decode a list of string notations to specify blocks inside the network.
+
+    Args:
+        string_list (list[str]): A list of strings, each string is a notation of block.
+
+    Returns:
+        blocks_args: A list of BlockArgs namedtuples of block args.
+    """
+    assert isinstance(string_list, list)
+    blocks_args = []
+    for block_string in string_list:
+        blocks_args.append(BlockDecoder._decode_block_string(block_string))
+    return blocks_args
+
+
+
+def encode(blocks_args) +
+
+

Encode a list of BlockArgs to a list of strings.

+

Args

+
+
blocks_args : list[namedtuples]
+
A list of BlockArgs namedtuples of block args.
+
+

Returns

+
+
block_strings
+
A list of strings, each string is a notation of block.
+
+
+ +Expand source code + +
@staticmethod
+def encode(blocks_args):
+    """Encode a list of BlockArgs to a list of strings.
+
+    Args:
+        blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
+
+    Returns:
+        block_strings: A list of strings, each string is a notation of block.
+    """
+    block_strings = []
+    for block in blocks_args:
+        block_strings.append(BlockDecoder._encode_block_string(block))
+    return block_strings
+
+
+
+
+
+class Conv2dDynamicSamePadding +(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True) +
+
+

2D Convolutions like TensorFlow, for a dynamic image size. +The padding is operated in forward function by calculating dynamically.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Conv2dDynamicSamePadding(nn.Conv2d):
+    """2D Convolutions like TensorFlow, for a dynamic image size.
+    The padding is operated in forward function by calculating dynamically.
+    """
+
+    # Tips for 'SAME' mode padding.
+    #     Given the following:
+    #         i: width or height
+    #         s: stride
+    #         k: kernel size
+    #         d: dilation
+    #         p: padding
+    #     Output after Conv2d:
+    #         o = floor((i+p-((k-1)*d+1))/s+1)
+    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
+    # => p = (i-1)*s+((k-1)*d+1)-i
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
+        )
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(
+            iw / sw
+        )  # change the output size according to stride ! ! !
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(
+                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+            )
+        return F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+

Ancestors

+
    +
  • torch.nn.modules.conv.Conv2d
  • +
  • torch.nn.modules.conv._ConvNd
  • +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    ih, iw = x.size()[-2:]
+    kh, kw = self.weight.size()[-2:]
+    sh, sw = self.stride
+    oh, ow = math.ceil(ih / sh), math.ceil(
+        iw / sw
+    )  # change the output size according to stride ! ! !
+    pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+    pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(
+            x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+        )
+    return F.conv2d(
+        x,
+        self.weight,
+        self.bias,
+        self.stride,
+        self.padding,
+        self.dilation,
+        self.groups,
+    )
+
+
+
+
+
+class Conv2dStaticSamePadding +(in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs) +
+
+

2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. +The padding module is calculated in the constructor, then used in forward.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Conv2dStaticSamePadding(nn.Conv2d):
+    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
+    The padding module is calculated in the constructor, then used in forward.
+    """
+
+    # With the same calculation as Conv2dDynamicSamePadding
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        image_size=None,
+        **kwargs
+    ):
+        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d(
+                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+            )
+        else:
+            self.static_padding = nn.Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.conv.Conv2d
  • +
  • torch.nn.modules.conv._ConvNd
  • +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.static_padding(x)
+    x = F.conv2d(
+        x,
+        self.weight,
+        self.bias,
+        self.stride,
+        self.padding,
+        self.dilation,
+        self.groups,
+    )
+    return x
+
+
+
+
+
+class GlobalParams +(width_coefficient=None, depth_coefficient=None, image_size=None, dropout_rate=None, num_classes=None, batch_norm_momentum=None, batch_norm_epsilon=None, drop_connect_rate=None, depth_divisor=None, min_depth=None, include_top=None) +
+
+

GlobalParams(width_coefficient, depth_coefficient, image_size, dropout_rate, num_classes, batch_norm_momentum, batch_norm_epsilon, drop_connect_rate, depth_divisor, min_depth, include_top)

+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var batch_norm_epsilon
+
+

Alias for field number 6

+
+
var batch_norm_momentum
+
+

Alias for field number 5

+
+
var depth_coefficient
+
+

Alias for field number 1

+
+
var depth_divisor
+
+

Alias for field number 8

+
+
var drop_connect_rate
+
+

Alias for field number 7

+
+
var dropout_rate
+
+

Alias for field number 3

+
+
var image_size
+
+

Alias for field number 2

+
+
var include_top
+
+

Alias for field number 10

+
+
var min_depth
+
+

Alias for field number 9

+
+
var num_classes
+
+

Alias for field number 4

+
+
var width_coefficient
+
+

Alias for field number 0

+
+
+
+
+class MaxPool2dDynamicSamePadding +(kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False) +
+
+

2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. +The padding is operated in forward function by calculating dynamically.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
+    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
+    The padding is operated in forward function by calculating dynamically.
+    """
+
+    def __init__(
+        self,
+        kernel_size,
+        stride,
+        padding=0,
+        dilation=1,
+        return_indices=False,
+        ceil_mode=False,
+    ):
+        super().__init__(
+            kernel_size, stride, padding, dilation, return_indices, ceil_mode
+        )
+        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+        self.kernel_size = (
+            [self.kernel_size] * 2
+            if isinstance(self.kernel_size, int)
+            else self.kernel_size
+        )
+        self.dilation = (
+            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+        )
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.kernel_size
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(
+                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+            )
+        return F.max_pool2d(
+            x,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.ceil_mode,
+            self.return_indices,
+        )
+
+

Ancestors

+
    +
  • torch.nn.modules.pooling.MaxPool2d
  • +
  • torch.nn.modules.pooling._MaxPoolNd
  • +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    ih, iw = x.size()[-2:]
+    kh, kw = self.kernel_size
+    sh, sw = self.stride
+    oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+    pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+    pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(
+            x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+        )
+    return F.max_pool2d(
+        x,
+        self.kernel_size,
+        self.stride,
+        self.padding,
+        self.dilation,
+        self.ceil_mode,
+        self.return_indices,
+    )
+
+
+
+
+
+class MaxPool2dStaticSamePadding +(kernel_size, stride, image_size=None, **kwargs) +
+
+

2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. +The padding module is calculated in the constructor, then used in forward.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MaxPool2dStaticSamePadding(nn.MaxPool2d):
+    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
+    The padding module is calculated in the constructor, then used in forward.
+    """
+
+    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
+        super().__init__(kernel_size, stride, **kwargs)
+        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+        self.kernel_size = (
+            [self.kernel_size] * 2
+            if isinstance(self.kernel_size, int)
+            else self.kernel_size
+        )
+        self.dilation = (
+            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+        )
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+        kh, kw = self.kernel_size
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d(
+                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+            )
+        else:
+            self.static_padding = nn.Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.max_pool2d(
+            x,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.ceil_mode,
+            self.return_indices,
+        )
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.pooling.MaxPool2d
  • +
  • torch.nn.modules.pooling._MaxPoolNd
  • +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.static_padding(x)
+    x = F.max_pool2d(
+        x,
+        self.kernel_size,
+        self.stride,
+        self.padding,
+        self.dilation,
+        self.ceil_mode,
+        self.return_indices,
+    )
+    return x
+
+
+
+
+
+class MemoryEfficientSwish +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    return SwishImplementation.apply(x)
+
+
+
+
+
+class Swish +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    return x * torch.sigmoid(x)
+
+
+
+
+
+class SwishImplementation +(*args, **kwargs) +
+
+

Base class to create custom autograd.Function

+

To create a custom autograd.Function, subclass this class and implement +the :meth:forward and :meth:backward static methods. Then, to use your custom +op in the forward pass, call the class method apply. Do not call +:meth:forward directly.

+

To ensure correctness and best performance, make sure you are calling the +correct methods on ctx and validating your backward function using +:func:torch.autograd.gradcheck.

+

See :ref:extending-autograd for more details on how to use this class.

+

Examples::

+
>>> class Exp(Function):
+>>>     @staticmethod
+>>>     def forward(ctx, i):
+>>>         result = i.exp()
+>>>         ctx.save_for_backward(result)
+>>>         return result
+>>>
+>>>     @staticmethod
+>>>     def backward(ctx, grad_output):
+>>>         result, = ctx.saved_tensors
+>>>         return grad_output * result
+>>>
+>>> # Use it by calling the apply method:
+>>> output = Exp.apply(input)
+
+
+ +Expand source code + +
class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+

Ancestors

+
    +
  • torch.autograd.function.Function
  • +
  • torch._C._FunctionBase
  • +
  • torch.autograd.function.FunctionCtx
  • +
  • torch.autograd.function._HookMixin
  • +
+

Static methods

+
+
+def backward(ctx, grad_output) +
+
+

Defines a formula for differentiating the operation with backward mode +automatic differentiation (alias to the vjp function).

+

This function is to be overridden by all subclasses.

+

It must accept a context :attr:ctx as the first argument, followed by +as many outputs as the :func:forward returned (None will be passed in +for non tensor outputs of the forward function), +and it should return as many tensors, as there were inputs to +:func:forward. Each argument is the gradient w.r.t the given output, +and each returned value should be the gradient w.r.t. the +corresponding input. If an input is not a Tensor or is a Tensor not +requiring grads, you can just pass None as a gradient for that input.

+

The context can be used to retrieve tensors saved during the forward +pass. It also has an attribute :attr:ctx.needs_input_grad as a tuple +of booleans representing whether each input needs gradient. E.g., +:func:backward will have ctx.needs_input_grad[0] = True if the +first input to :func:forward needs gradient computed w.r.t. the +output.

+
+ +Expand source code + +
@staticmethod
+def backward(ctx, grad_output):
+    i = ctx.saved_tensors[0]
+    sigmoid_i = torch.sigmoid(i)
+    return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+
+def forward(ctx, i) +
+
+

Performs the operation.

+

This function is to be overridden by all subclasses.

+

It must accept a context ctx as the first argument, followed by any +number of arguments (tensors or other types).

+

The context can be used to store arbitrary data that can be then +retrieved during the backward pass. Tensors should not be stored +directly on ctx (though this is not currently enforced for +backward compatibility). Instead, tensors should be saved either with +:func:ctx.save_for_backward if they are intended to be used in +backward (equivalently, vjp) or :func:ctx.save_for_forward +if they are intended to be used in jvp.

+
+ +Expand source code + +
@staticmethod
+def forward(ctx, i):
+    result = i * torch.sigmoid(i)
+    ctx.save_for_backward(i)
+    return result
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/tracerb7/efficientnet.html b/docs/api/carvekit/ml/arch/tracerb7/efficientnet.html new file mode 100644 index 0000000..192582a --- /dev/null +++ b/docs/api/carvekit/ml/arch/tracerb7/efficientnet.html @@ -0,0 +1,1077 @@ + + + + + + +carvekit.ml.arch.tracerb7.efficientnet API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.tracerb7.efficientnet

+
+
+

Source url: https://github.com/lukemelas/EfficientNet-PyTorch +Modified by Min Seok Lee, Wooseok Shin, Nikita Selin +License: Apache License 2.0

+

Changes

+
    +
  • Added support for extracting edge features
  • +
  • Added support for extracting object features at different levels
  • +
  • Refactored the code
  • +
+
+ +Expand source code + +
"""
+Source url: https://github.com/lukemelas/EfficientNet-PyTorch
+Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
+License: Apache License 2.0
+Changes:
+    - Added support for extracting edge features
+    - Added support for extracting object features at different levels
+    - Refactored the code
+"""
+from typing import Any, List
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from carvekit.ml.arch.tracerb7.effi_utils import (
+    get_same_padding_conv2d,
+    calculate_output_image_size,
+    MemoryEfficientSwish,
+    drop_connect,
+    round_filters,
+    round_repeats,
+    Swish,
+    create_block_args,
+)
+
+
+class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck Block.
+
+    Args:
+        block_args (namedtuple): BlockArgs, defined in utils.py.
+        global_params (namedtuple): GlobalParam, defined in utils.py.
+        image_size (tuple or list): [image_height, image_width].
+
+    References:
+        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+    """
+
+    def __init__(self, block_args, global_params, image_size=None):
+        super().__init__()
+        self._block_args = block_args
+        self._bn_mom = (
+            1 - global_params.batch_norm_momentum
+        )  # pytorch's difference from tensorflow
+        self._bn_eps = global_params.batch_norm_epsilon
+        self.has_se = (self._block_args.se_ratio is not None) and (
+            0 < self._block_args.se_ratio <= 1
+        )
+        self.id_skip = (
+            block_args.id_skip
+        )  # whether to use skip connection and drop connect
+
+        # Expansion phase (Inverted Bottleneck)
+        inp = self._block_args.input_filters  # number of input channels
+        oup = (
+            self._block_args.input_filters * self._block_args.expand_ratio
+        )  # number of output channels
+        if self._block_args.expand_ratio != 1:
+            Conv2d = get_same_padding_conv2d(image_size=image_size)
+            self._expand_conv = Conv2d(
+                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
+            )
+            self._bn0 = nn.BatchNorm2d(
+                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
+            )
+            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+        # Depthwise convolution phase
+        k = self._block_args.kernel_size
+        s = self._block_args.stride
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._depthwise_conv = Conv2d(
+            in_channels=oup,
+            out_channels=oup,
+            groups=oup,  # groups makes it depthwise
+            kernel_size=k,
+            stride=s,
+            bias=False,
+        )
+        self._bn1 = nn.BatchNorm2d(
+            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
+        )
+        image_size = calculate_output_image_size(image_size, s)
+
+        # Squeeze and Excitation layer, if desired
+        if self.has_se:
+            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+            num_squeezed_channels = max(
+                1, int(self._block_args.input_filters * self._block_args.se_ratio)
+            )
+            self._se_reduce = Conv2d(
+                in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
+            )
+            self._se_expand = Conv2d(
+                in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
+            )
+
+        # Pointwise convolution phase
+        final_oup = self._block_args.output_filters
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._project_conv = Conv2d(
+            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
+        )
+        self._bn2 = nn.BatchNorm2d(
+            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
+        )
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs, drop_connect_rate=None):
+        """MBConvBlock's forward function.
+
+        Args:
+            inputs (tensor): Input tensor.
+            drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
+
+        Returns:
+            Output of this block after processing.
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._block_args.expand_ratio != 1:
+            x = self._expand_conv(inputs)
+            x = self._bn0(x)
+            x = self._swish(x)
+
+        x = self._depthwise_conv(x)
+        x = self._bn1(x)
+        x = self._swish(x)
+
+        # Squeeze and Excitation
+        if self.has_se:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._se_reduce(x_squeezed)
+            x_squeezed = self._swish(x_squeezed)
+            x_squeezed = self._se_expand(x_squeezed)
+            x = torch.sigmoid(x_squeezed) * x
+
+        # Pointwise Convolution
+        x = self._project_conv(x)
+        x = self._bn2(x)
+
+        # Skip connection and drop connect
+        input_filters, output_filters = (
+            self._block_args.input_filters,
+            self._block_args.output_filters,
+        )
+        if (
+            self.id_skip
+            and self._block_args.stride == 1
+            and input_filters == output_filters
+        ):
+            # The combination of skip connection and drop connect brings about stochastic depth.
+            if drop_connect_rate:
+                x = drop_connect(x, p=drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+class EfficientNet(nn.Module):
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__()
+        assert isinstance(blocks_args, list), "blocks_args should be a list"
+        assert len(blocks_args) > 0, "block args must be greater than 0"
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Get stem static or dynamic convolution depending on image size
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+        # Stem
+        in_channels = 3  # rgb
+        out_channels = round_filters(
+            32, self._global_params
+        )  # number of output channels
+        self._conv_stem = Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=2, bias=False
+        )
+        self._bn0 = nn.BatchNorm2d(
+            num_features=out_channels, momentum=bn_mom, eps=bn_eps
+        )
+        image_size = calculate_output_image_size(image_size, 2)
+
+        # Build blocks
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(
+                    block_args.input_filters, self._global_params
+                ),
+                output_filters=round_filters(
+                    block_args.output_filters, self._global_params
+                ),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(
+                MBConvBlock(block_args, self._global_params, image_size=image_size)
+            )
+            image_size = calculate_output_image_size(image_size, block_args.stride)
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                block_args = block_args._replace(
+                    input_filters=block_args.output_filters, stride=1
+                )
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(
+                    MBConvBlock(block_args, self._global_params, image_size=image_size)
+                )
+                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+    def extract_endpoints(self, inputs):
+        endpoints = dict()
+
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        prev_x = x
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(
+                    self._blocks
+                )  # scale drop connect_rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            if prev_x.size(2) > x.size(2):
+                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
+            prev_x = x
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+        endpoints["reduction_{}".format(len(endpoints) + 1)] = x
+
+        return endpoints
+
+    def _change_in_channels(self, in_channels):
+        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+
+        Args:
+            in_channels (int): Input data's channel number.
+        """
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+            out_channels = round_filters(32, self._global_params)
+            self._conv_stem = Conv2d(
+                in_channels, out_channels, kernel_size=3, stride=2, bias=False
+            )
+
+
+class EfficientEncoderB7(EfficientNet):
+    def __init__(self):
+        super().__init__(
+            *create_block_args(
+                width_coefficient=2.0,
+                depth_coefficient=3.1,
+                dropout_rate=0.5,
+                image_size=600,
+            )
+        )
+        self._change_in_channels(3)
+        self.block_idx = [10, 17, 37, 54]
+        self.channels = [48, 80, 224, 640]
+
+    def initial_conv(self, inputs):
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        return x
+
+    def get_blocks(self, x, H, W, block_idx):
+        features = []
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(
+                    self._blocks
+                )  # scale drop connect_rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            if idx == block_idx[0]:
+                features.append(x.clone())
+            if idx == block_idx[1]:
+                features.append(x.clone())
+            if idx == block_idx[2]:
+                features.append(x.clone())
+            if idx == block_idx[3]:
+                features.append(x.clone())
+
+        return features
+
+    def forward(self, inputs: torch.Tensor) -> List[Any]:
+        B, C, H, W = inputs.size()
+        x = self.initial_conv(inputs)  # Prepare input for the backbone
+        return self.get_blocks(
+            x, H, W, block_idx=self.block_idx
+        )  # Get backbone features and edge maps
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class EfficientEncoderB7 +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EfficientEncoderB7(EfficientNet):
+    def __init__(self):
+        super().__init__(
+            *create_block_args(
+                width_coefficient=2.0,
+                depth_coefficient=3.1,
+                dropout_rate=0.5,
+                image_size=600,
+            )
+        )
+        self._change_in_channels(3)
+        self.block_idx = [10, 17, 37, 54]
+        self.channels = [48, 80, 224, 640]
+
+    def initial_conv(self, inputs):
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        return x
+
+    def get_blocks(self, x, H, W, block_idx):
+        features = []
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(
+                    self._blocks
+                )  # scale drop connect_rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            if idx == block_idx[0]:
+                features.append(x.clone())
+            if idx == block_idx[1]:
+                features.append(x.clone())
+            if idx == block_idx[2]:
+                features.append(x.clone())
+            if idx == block_idx[3]:
+                features.append(x.clone())
+
+        return features
+
+    def forward(self, inputs: torch.Tensor) -> List[Any]:
+        B, C, H, W = inputs.size()
+        x = self.initial_conv(inputs)  # Prepare input for the backbone
+        return self.get_blocks(
+            x, H, W, block_idx=self.block_idx
+        )  # Get backbone features and edge maps
+
+

Ancestors

+ +

Methods

+
+
+def forward(self, inputs: torch.Tensor) ‑> List[Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, inputs: torch.Tensor) -> List[Any]:
+    B, C, H, W = inputs.size()
+    x = self.initial_conv(inputs)  # Prepare input for the backbone
+    return self.get_blocks(
+        x, H, W, block_idx=self.block_idx
+    )  # Get backbone features and edge maps
+
+
+
+def get_blocks(self, x, H, W, block_idx) +
+
+
+
+ +Expand source code + +
def get_blocks(self, x, H, W, block_idx):
+    features = []
+    for idx, block in enumerate(self._blocks):
+        drop_connect_rate = self._global_params.drop_connect_rate
+        if drop_connect_rate:
+            drop_connect_rate *= float(idx) / len(
+                self._blocks
+            )  # scale drop connect_rate
+        x = block(x, drop_connect_rate=drop_connect_rate)
+        if idx == block_idx[0]:
+            features.append(x.clone())
+        if idx == block_idx[1]:
+            features.append(x.clone())
+        if idx == block_idx[2]:
+            features.append(x.clone())
+        if idx == block_idx[3]:
+            features.append(x.clone())
+
+    return features
+
+
+
+def initial_conv(self, inputs) +
+
+
+
+ +Expand source code + +
def initial_conv(self, inputs):
+    x = self._swish(self._bn0(self._conv_stem(inputs)))
+    return x
+
+
+
+

Inherited members

+ +
+
+class EfficientNet +(blocks_args=None, global_params=None) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EfficientNet(nn.Module):
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__()
+        assert isinstance(blocks_args, list), "blocks_args should be a list"
+        assert len(blocks_args) > 0, "block args must be greater than 0"
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Get stem static or dynamic convolution depending on image size
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+        # Stem
+        in_channels = 3  # rgb
+        out_channels = round_filters(
+            32, self._global_params
+        )  # number of output channels
+        self._conv_stem = Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=2, bias=False
+        )
+        self._bn0 = nn.BatchNorm2d(
+            num_features=out_channels, momentum=bn_mom, eps=bn_eps
+        )
+        image_size = calculate_output_image_size(image_size, 2)
+
+        # Build blocks
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(
+                    block_args.input_filters, self._global_params
+                ),
+                output_filters=round_filters(
+                    block_args.output_filters, self._global_params
+                ),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(
+                MBConvBlock(block_args, self._global_params, image_size=image_size)
+            )
+            image_size = calculate_output_image_size(image_size, block_args.stride)
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                block_args = block_args._replace(
+                    input_filters=block_args.output_filters, stride=1
+                )
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(
+                    MBConvBlock(block_args, self._global_params, image_size=image_size)
+                )
+                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+    def extract_endpoints(self, inputs):
+        endpoints = dict()
+
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        prev_x = x
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(
+                    self._blocks
+                )  # scale drop connect_rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            if prev_x.size(2) > x.size(2):
+                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
+            prev_x = x
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+        endpoints["reduction_{}".format(len(endpoints) + 1)] = x
+
+        return endpoints
+
+    def _change_in_channels(self, in_channels):
+        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+
+        Args:
+            in_channels (int): Input data's channel number.
+        """
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+            out_channels = round_filters(32, self._global_params)
+            self._conv_stem = Conv2d(
+                in_channels, out_channels, kernel_size=3, stride=2, bias=False
+            )
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def extract_endpoints(self, inputs) +
+
+
+
+ +Expand source code + +
def extract_endpoints(self, inputs):
+    endpoints = dict()
+
+    # Stem
+    x = self._swish(self._bn0(self._conv_stem(inputs)))
+    prev_x = x
+
+    # Blocks
+    for idx, block in enumerate(self._blocks):
+        drop_connect_rate = self._global_params.drop_connect_rate
+        if drop_connect_rate:
+            drop_connect_rate *= float(idx) / len(
+                self._blocks
+            )  # scale drop connect_rate
+        x = block(x, drop_connect_rate=drop_connect_rate)
+        if prev_x.size(2) > x.size(2):
+            endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
+        prev_x = x
+
+    # Head
+    x = self._swish(self._bn1(self._conv_head(x)))
+    endpoints["reduction_{}".format(len(endpoints) + 1)] = x
+
+    return endpoints
+
+
+
+def set_swish(self, memory_efficient=True) +
+
+

Sets swish function as memory efficient (for training) or standard (for export).

+

Args

+
+
memory_efficient : bool
+
Whether to use memory-efficient version of swish.
+
+
+ +Expand source code + +
def set_swish(self, memory_efficient=True):
+    """Sets swish function as memory efficient (for training) or standard (for export).
+
+    Args:
+        memory_efficient (bool): Whether to use memory-efficient version of swish.
+
+    """
+    self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+    for block in self._blocks:
+        block.set_swish(memory_efficient)
+
+
+
+
+
+class MBConvBlock +(block_args, global_params, image_size=None) +
+
+

Mobile Inverted Residual Bottleneck Block.

+

Args

+
+
block_args : namedtuple
+
BlockArgs, defined in utils.py.
+
global_params : namedtuple
+
GlobalParam, defined in utils.py.
+
image_size : tuple or list
+
[image_height, image_width].
+
+

References

+

[1] https://arxiv.org/abs/1704.04861 (MobileNet v1) +[2] https://arxiv.org/abs/1801.04381 (MobileNet v2) +[3] https://arxiv.org/abs/1905.02244 (MobileNet v3)

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck Block.
+
+    Args:
+        block_args (namedtuple): BlockArgs, defined in utils.py.
+        global_params (namedtuple): GlobalParam, defined in utils.py.
+        image_size (tuple or list): [image_height, image_width].
+
+    References:
+        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+    """
+
+    def __init__(self, block_args, global_params, image_size=None):
+        super().__init__()
+        self._block_args = block_args
+        self._bn_mom = (
+            1 - global_params.batch_norm_momentum
+        )  # pytorch's difference from tensorflow
+        self._bn_eps = global_params.batch_norm_epsilon
+        self.has_se = (self._block_args.se_ratio is not None) and (
+            0 < self._block_args.se_ratio <= 1
+        )
+        self.id_skip = (
+            block_args.id_skip
+        )  # whether to use skip connection and drop connect
+
+        # Expansion phase (Inverted Bottleneck)
+        inp = self._block_args.input_filters  # number of input channels
+        oup = (
+            self._block_args.input_filters * self._block_args.expand_ratio
+        )  # number of output channels
+        if self._block_args.expand_ratio != 1:
+            Conv2d = get_same_padding_conv2d(image_size=image_size)
+            self._expand_conv = Conv2d(
+                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
+            )
+            self._bn0 = nn.BatchNorm2d(
+                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
+            )
+            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+        # Depthwise convolution phase
+        k = self._block_args.kernel_size
+        s = self._block_args.stride
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._depthwise_conv = Conv2d(
+            in_channels=oup,
+            out_channels=oup,
+            groups=oup,  # groups makes it depthwise
+            kernel_size=k,
+            stride=s,
+            bias=False,
+        )
+        self._bn1 = nn.BatchNorm2d(
+            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
+        )
+        image_size = calculate_output_image_size(image_size, s)
+
+        # Squeeze and Excitation layer, if desired
+        if self.has_se:
+            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+            num_squeezed_channels = max(
+                1, int(self._block_args.input_filters * self._block_args.se_ratio)
+            )
+            self._se_reduce = Conv2d(
+                in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
+            )
+            self._se_expand = Conv2d(
+                in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
+            )
+
+        # Pointwise convolution phase
+        final_oup = self._block_args.output_filters
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._project_conv = Conv2d(
+            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
+        )
+        self._bn2 = nn.BatchNorm2d(
+            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
+        )
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs, drop_connect_rate=None):
+        """MBConvBlock's forward function.
+
+        Args:
+            inputs (tensor): Input tensor.
+            drop_connect_rate (float): Drop connect rate, between 0 and 1.
+
+        Returns:
+            Output of this block after processing.
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._block_args.expand_ratio != 1:
+            x = self._expand_conv(inputs)
+            x = self._bn0(x)
+            x = self._swish(x)
+
+        x = self._depthwise_conv(x)
+        x = self._bn1(x)
+        x = self._swish(x)
+
+        # Squeeze and Excitation
+        if self.has_se:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._se_reduce(x_squeezed)
+            x_squeezed = self._swish(x_squeezed)
+            x_squeezed = self._se_expand(x_squeezed)
+            x = torch.sigmoid(x_squeezed) * x
+
+        # Pointwise Convolution
+        x = self._project_conv(x)
+        x = self._bn2(x)
+
+        # Skip connection and drop connect
+        input_filters, output_filters = (
+            self._block_args.input_filters,
+            self._block_args.output_filters,
+        )
+        if (
+            self.id_skip
+            and self._block_args.stride == 1
+            and input_filters == output_filters
+        ):
+            # The combination of skip connection and drop connect brings about stochastic depth.
+            if drop_connect_rate:
+                x = drop_connect(x, p=drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, inputs, drop_connect_rate=None) ‑> Callable[..., Any] +
+
+

MBConvBlock's forward function.

+

Args

+
+
inputs : tensor
+
Input tensor.
+
drop_connect_rate : float
+
Drop connect rate (float, between 0 and 1).
+
+

Returns

+

Output of this block after processing.

+
+ +Expand source code + +
def forward(self, inputs, drop_connect_rate=None):
+    """MBConvBlock's forward function.
+
+    Args:
+        inputs (tensor): Input tensor.
+        drop_connect_rate (float): Drop connect rate, between 0 and 1.
+
+    Returns:
+        Output of this block after processing.
+    """
+
+    # Expansion and Depthwise Convolution
+    x = inputs
+    if self._block_args.expand_ratio != 1:
+        x = self._expand_conv(inputs)
+        x = self._bn0(x)
+        x = self._swish(x)
+
+    x = self._depthwise_conv(x)
+    x = self._bn1(x)
+    x = self._swish(x)
+
+    # Squeeze and Excitation
+    if self.has_se:
+        x_squeezed = F.adaptive_avg_pool2d(x, 1)
+        x_squeezed = self._se_reduce(x_squeezed)
+        x_squeezed = self._swish(x_squeezed)
+        x_squeezed = self._se_expand(x_squeezed)
+        x = torch.sigmoid(x_squeezed) * x
+
+    # Pointwise Convolution
+    x = self._project_conv(x)
+    x = self._bn2(x)
+
+    # Skip connection and drop connect
+    input_filters, output_filters = (
+        self._block_args.input_filters,
+        self._block_args.output_filters,
+    )
+    if (
+        self.id_skip
+        and self._block_args.stride == 1
+        and input_filters == output_filters
+    ):
+        # The combination of skip connection and drop connect brings about stochastic depth.
+        if drop_connect_rate:
+            x = drop_connect(x, p=drop_connect_rate, training=self.training)
+        x = x + inputs  # skip connection
+    return x
+
+
+
+def set_swish(self, memory_efficient=True) +
+
+

Sets swish function as memory efficient (for training) or standard (for export).

+

Args

+
+
memory_efficient : bool
+
Whether to use memory-efficient version of swish.
+
+
+ +Expand source code + +
def set_swish(self, memory_efficient=True):
+    """Sets swish function as memory efficient (for training) or standard (for export).
+
+    Args:
+        memory_efficient (bool): Whether to use memory-efficient version of swish.
+    """
+    self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/tracerb7/index.html b/docs/api/carvekit/ml/arch/tracerb7/index.html new file mode 100644 index 0000000..69a9c40 --- /dev/null +++ b/docs/api/carvekit/ml/arch/tracerb7/index.html @@ -0,0 +1,96 @@ + + + + + + +carvekit.ml.arch.tracerb7 API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.tracerb7

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.tracerb7.att_modules
+
+

Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +License: Apache License 2.0

+
+
carvekit.ml.arch.tracerb7.conv_modules
+
+

Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +License: Apache License 2.0

+
+
carvekit.ml.arch.tracerb7.effi_utils
+
+

Original author: lukemelas (github username) +Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +With adjustments and added comments by …

+
+
carvekit.ml.arch.tracerb7.efficientnet
+
+

Source url: https://github.com/lukemelas/EfficientNet-PyTorch +Modified by Min Seok Lee, Wooseok Shin, Nikita Selin +License: Apache License 2.0 +…

+
+
carvekit.ml.arch.tracerb7.tracer
+
+

Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +Modified by Nikita Selin …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/tracerb7/tracer.html b/docs/api/carvekit/ml/arch/tracerb7/tracer.html new file mode 100644 index 0000000..8885c88 --- /dev/null +++ b/docs/api/carvekit/ml/arch/tracerb7/tracer.html @@ -0,0 +1,332 @@ + + + + + + +carvekit.ml.arch.tracerb7.tracer API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.tracerb7.tracer

+
+
+

Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+

Changes

+
    +
  • Refactored code
  • +
  • Removed unused code
  • +
  • Added comments
  • +
+
+ +Expand source code + +
"""
+Source url: https://github.com/Karel911/TRACER
+Author: Min Seok Lee and Wooseok Shin
+Modified by Nikita Selin [OPHoperHPO](https://github.com/OPHoperHPO).
+License: Apache License 2.0
+Changes:
+    - Refactored code
+    - Removed unused code
+    - Added comments
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Optional, Tuple
+
+from torch import Tensor
+
+from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
+from carvekit.ml.arch.tracerb7.att_modules import (
+    RFB_Block,
+    aggregation,
+    ObjectAttention,
+)
+
+
+class TracerDecoder(nn.Module):
+    """Tracer Decoder"""
+
+    def __init__(
+        self,
+        encoder: EfficientEncoderB7,
+        features_channels: Optional[List[int]] = None,
+        rfb_channel: Optional[List[int]] = None,
+    ):
+        """
+        Initialize the tracer decoder.
+
+        Args:
+            encoder: The encoder to use.
+            features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640]
+            rfb_channel: The channels of the RFB features. default: [32, 64, 128]
+        """
+        super().__init__()
+        if rfb_channel is None:
+            rfb_channel = [32, 64, 128]
+        if features_channels is None:
+            features_channels = [48, 80, 224, 640]
+        self.encoder = encoder
+        self.features_channels = features_channels
+
+        # Receptive Field Blocks
+        features_channels = rfb_channel
+        self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0])
+        self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1])
+        self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2])
+
+        # Multi-level aggregation
+        self.agg = aggregation(features_channels)
+
+        # Object Attention
+        self.ObjectAttention2 = ObjectAttention(
+            channel=self.features_channels[1], kernel_size=3
+        )
+        self.ObjectAttention1 = ObjectAttention(
+            channel=self.features_channels[0], kernel_size=3
+        )
+
+    def forward(self, inputs: torch.Tensor) -> Tensor:
+        """
+        Forward pass of the tracer decoder.
+
+        Args:
+            inputs: Preprocessed images.
+
+        Returns:
+            Tensor of segmentation masks: sigmoid of the average of the three upsampled decoder maps.
+        """
+        features = self.encoder(inputs)
+        x3_rfb = self.rfb2(features[1])
+        x4_rfb = self.rfb3(features[2])
+        x5_rfb = self.rfb4(features[3])
+
+        D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
+
+        ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")
+
+        D_1 = self.ObjectAttention2(D_0, features[1])
+        ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")
+
+        ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
+        D_2 = self.ObjectAttention1(ds_map, features[0])
+        ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")
+
+        final_map = (ds_map2 + ds_map1 + ds_map0) / 3
+
+        return torch.sigmoid(final_map)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class TracerDecoder +(encoder: EfficientEncoderB7, features_channels: Optional[List[int]] = None, rfb_channel: Optional[List[int]] = None) +
+
+

Tracer Decoder

+

Initialize the tracer decoder.

+

Args

+
+
encoder
+
The encoder to use.
+
features_channels
+
The channels of the backbone features at different stages. default: [48, 80, 224, 640]
+
rfb_channel
+
The channels of the RFB features. default: [32, 64, 128]
+
+
+ +Expand source code + +
class TracerDecoder(nn.Module):
+    """Tracer Decoder"""
+
+    def __init__(
+        self,
+        encoder: EfficientEncoderB7,
+        features_channels: Optional[List[int]] = None,
+        rfb_channel: Optional[List[int]] = None,
+    ):
+        """
+        Initialize the tracer decoder.
+
+        Args:
+            encoder: The encoder to use.
+            features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640]
+            rfb_channel: The channels of the RFB features. default: [32, 64, 128]
+        """
+        super().__init__()
+        if rfb_channel is None:
+            rfb_channel = [32, 64, 128]
+        if features_channels is None:
+            features_channels = [48, 80, 224, 640]
+        self.encoder = encoder
+        self.features_channels = features_channels
+
+        # Receptive Field Blocks
+        features_channels = rfb_channel
+        self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0])
+        self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1])
+        self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2])
+
+        # Multi-level aggregation
+        self.agg = aggregation(features_channels)
+
+        # Object Attention
+        self.ObjectAttention2 = ObjectAttention(
+            channel=self.features_channels[1], kernel_size=3
+        )
+        self.ObjectAttention1 = ObjectAttention(
+            channel=self.features_channels[0], kernel_size=3
+        )
+
+    def forward(self, inputs: torch.Tensor) -> Tensor:
+        """
+        Forward pass of the tracer decoder.
+
+        Args:
+            inputs: Preprocessed images.
+
+        Returns:
+            Tensor of segmentation masks: sigmoid of the average of the three upsampled decoder maps.
+        """
+        features = self.encoder(inputs)
+        x3_rfb = self.rfb2(features[1])
+        x4_rfb = self.rfb3(features[2])
+        x5_rfb = self.rfb4(features[3])
+
+        D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
+
+        ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")
+
+        D_1 = self.ObjectAttention2(D_0, features[1])
+        ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")
+
+        ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
+        D_2 = self.ObjectAttention1(ds_map, features[0])
+        ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")
+
+        final_map = (ds_map2 + ds_map1 + ds_map0) / 3
+
+        return torch.sigmoid(final_map)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def forward(self, inputs: torch.Tensor) ‑> torch.Tensor +
+
+

Forward pass of the tracer decoder.

+

Args

+
+
inputs
+
Preprocessed images.
+
+

Returns

+

Tensor of segmentation masks: sigmoid of the average of the three upsampled decoder maps.

+
+ +Expand source code + +
def forward(self, inputs: torch.Tensor) -> Tensor:
+    """
+    Forward pass of the tracer decoder.
+
+    Args:
+        inputs: Preprocessed images.
+
+    Returns:
+        Tensor of segmentation masks: sigmoid of the average of the three upsampled decoder maps.
+    """
+    features = self.encoder(inputs)
+    x3_rfb = self.rfb2(features[1])
+    x4_rfb = self.rfb3(features[2])
+    x5_rfb = self.rfb4(features[3])
+
+    D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
+
+    ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")
+
+    D_1 = self.ObjectAttention2(D_0, features[1])
+    ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")
+
+    ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
+    D_2 = self.ObjectAttention1(ds_map, features[0])
+    ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")
+
+    final_map = (ds_map2 + ds_map1 + ds_map0) / 3
+
+    return torch.sigmoid(final_map)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/u2net/index.html b/docs/api/carvekit/ml/arch/u2net/index.html new file mode 100644 index 0000000..468a43f --- /dev/null +++ b/docs/api/carvekit/ml/arch/u2net/index.html @@ -0,0 +1,67 @@ + + + + + + +carvekit.ml.arch.u2net API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.u2net

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.u2net.u2net
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/xuebinqin/U-2-Net +License: Apache License 2.0

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/u2net/u2net.html b/docs/api/carvekit/ml/arch/u2net/u2net.html new file mode 100644 index 0000000..c1dfb98 --- /dev/null +++ b/docs/api/carvekit/ml/arch/u2net/u2net.html @@ -0,0 +1,431 @@ + + + + + + +carvekit.ml.arch.u2net.u2net API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.u2net.u2net

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/xuebinqin/U-2-Net +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Modified by Nikita Selin [OPHoperHPO](https://github.com/OPHoperHPO).
+Source url: https://github.com/xuebinqin/U-2-Net
+License: Apache License 2.0
+"""
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+import math
+
+__all__ = ["U2NETArchitecture"]
+
+
+def _upsample_like(x, size):
+    return nn.Upsample(size=size, mode="bilinear", align_corners=False)(x)
+
+
+def _size_map(x, height):
+    # {height: size} for Upsample
+    size = list(x.shape[-2:])
+    sizes = {}
+    for h in range(1, height):
+        sizes[h] = size
+        size = [math.ceil(w / 2) for w in size]
+    return sizes
+
+
+class REBNCONV(nn.Module):
+    def __init__(self, in_ch=3, out_ch=3, dilate=1):
+        super(REBNCONV, self).__init__()
+
+        self.conv_s1 = nn.Conv2d(
+            in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate
+        )
+        self.bn_s1 = nn.BatchNorm2d(out_ch)
+        self.relu_s1 = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
+
+
+class RSU(nn.Module):
+    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
+        super(RSU, self).__init__()
+        self.name = name
+        self.height = height
+        self.dilated = dilated
+        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
+
+    def forward(self, x):
+        sizes = _size_map(x, self.height)
+        x = self.rebnconvin(x)
+
+        # U-Net like symmetric encoder-decoder structure
+        def unet(x, height=1):
+            if height < self.height:
+                x1 = getattr(self, f"rebnconv{height}")(x)
+                if not self.dilated and height < self.height - 1:
+                    x2 = unet(getattr(self, "downsample")(x1), height + 1)
+                else:
+                    x2 = unet(x1, height + 1)
+
+                x = getattr(self, f"rebnconv{height}d")(torch.cat((x2, x1), 1))
+                return (
+                    _upsample_like(x, sizes[height - 1])
+                    if not self.dilated and height > 1
+                    else x
+                )
+            else:
+                return getattr(self, f"rebnconv{height}")(x)
+
+        return x + unet(x)
+
+    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
+        self.add_module("rebnconvin", REBNCONV(in_ch, out_ch))
+        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
+
+        self.add_module("rebnconv1", REBNCONV(out_ch, mid_ch))
+        self.add_module("rebnconv1d", REBNCONV(mid_ch * 2, out_ch))
+
+        for i in range(2, height):
+            dilate = 1 if not dilated else 2 ** (i - 1)
+            self.add_module(f"rebnconv{i}", REBNCONV(mid_ch, mid_ch, dilate=dilate))
+            self.add_module(
+                f"rebnconv{i}d", REBNCONV(mid_ch * 2, mid_ch, dilate=dilate)
+            )
+
+        dilate = 2 if not dilated else 2 ** (height - 1)
+        self.add_module(f"rebnconv{height}", REBNCONV(mid_ch, mid_ch, dilate=dilate))
+
+
+class U2NETArchitecture(nn.Module):
+    def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1):
+        super(U2NETArchitecture, self).__init__()
+        if isinstance(cfg_type, str):
+            if cfg_type == "full":
+                layers_cfgs = {
+                    # cfgs for building RSUs and sides
+                    # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
+                    "stage1": ["En_1", (7, 3, 32, 64), -1],
+                    "stage2": ["En_2", (6, 64, 32, 128), -1],
+                    "stage3": ["En_3", (5, 128, 64, 256), -1],
+                    "stage4": ["En_4", (4, 256, 128, 512), -1],
+                    "stage5": ["En_5", (4, 512, 256, 512, True), -1],
+                    "stage6": ["En_6", (4, 512, 256, 512, True), 512],
+                    "stage5d": ["De_5", (4, 1024, 256, 512, True), 512],
+                    "stage4d": ["De_4", (4, 1024, 128, 256), 256],
+                    "stage3d": ["De_3", (5, 512, 64, 128), 128],
+                    "stage2d": ["De_2", (6, 256, 32, 64), 64],
+                    "stage1d": ["De_1", (7, 128, 16, 64), 64],
+                }
+            else:
+                raise ValueError("Unknown U^2-Net architecture conf. name")
+        elif isinstance(cfg_type, dict):
+            layers_cfgs = cfg_type
+        else:
+            raise ValueError("Unknown U^2-Net architecture conf. type")
+        self.out_ch = out_ch
+        self._make_layers(layers_cfgs)
+
+    def forward(self, x):
+        sizes = _size_map(x, self.height)
+        maps = []  # storage for maps
+
+        # side saliency map
+        def unet(x, height=1):
+            if height < 6:
+                x1 = getattr(self, f"stage{height}")(x)
+                x2 = unet(getattr(self, "downsample")(x1), height + 1)
+                x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
+                side(x, height)
+                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
+            else:
+                x = getattr(self, f"stage{height}")(x)
+                side(x, height)
+                return _upsample_like(x, sizes[height - 1])
+
+        def side(x, h):
+            # side output saliency map (before sigmoid)
+            x = getattr(self, f"side{h}")(x)
+            x = _upsample_like(x, sizes[1])
+            maps.append(x)
+
+        def fuse():
+            # fuse saliency probability maps
+            maps.reverse()
+            x = torch.cat(maps, 1)
+            x = getattr(self, "outconv")(x)
+            maps.insert(0, x)
+            return [torch.sigmoid(x) for x in maps]
+
+        unet(x)
+        maps = fuse()
+        return maps
+
+    def _make_layers(self, cfgs):
+        self.height = int((len(cfgs) + 1) / 2)
+        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
+        for k, v in cfgs.items():
+            # build rsu block
+            self.add_module(k, RSU(v[0], *v[1]))
+            if v[2] > 0:
+                # build side layer
+                self.add_module(
+                    f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1)
+                )
+        # build fuse layer
+        self.add_module(
+            "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class U2NETArchitecture +(cfg_type: Union[dict, str] = 'full', out_ch: int = 1) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call the to() method, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class U2NETArchitecture(nn.Module):
+    def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1):
+        super(U2NETArchitecture, self).__init__()
+        if isinstance(cfg_type, str):
+            if cfg_type == "full":
+                layers_cfgs = {
+                    # cfgs for building RSUs and sides
+                    # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
+                    "stage1": ["En_1", (7, 3, 32, 64), -1],
+                    "stage2": ["En_2", (6, 64, 32, 128), -1],
+                    "stage3": ["En_3", (5, 128, 64, 256), -1],
+                    "stage4": ["En_4", (4, 256, 128, 512), -1],
+                    "stage5": ["En_5", (4, 512, 256, 512, True), -1],
+                    "stage6": ["En_6", (4, 512, 256, 512, True), 512],
+                    "stage5d": ["De_5", (4, 1024, 256, 512, True), 512],
+                    "stage4d": ["De_4", (4, 1024, 128, 256), 256],
+                    "stage3d": ["De_3", (5, 512, 64, 128), 128],
+                    "stage2d": ["De_2", (6, 256, 32, 64), 64],
+                    "stage1d": ["De_1", (7, 128, 16, 64), 64],
+                }
+            else:
+                raise ValueError("Unknown U^2-Net architecture conf. name")
+        elif isinstance(cfg_type, dict):
+            layers_cfgs = cfg_type
+        else:
+            raise ValueError("Unknown U^2-Net architecture conf. type")
+        self.out_ch = out_ch
+        self._make_layers(layers_cfgs)
+
+    def forward(self, x):
+        sizes = _size_map(x, self.height)
+        maps = []  # storage for maps
+
+        # side saliency map
+        def unet(x, height=1):
+            if height < 6:
+                x1 = getattr(self, f"stage{height}")(x)
+                x2 = unet(getattr(self, "downsample")(x1), height + 1)
+                x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
+                side(x, height)
+                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
+            else:
+                x = getattr(self, f"stage{height}")(x)
+                side(x, height)
+                return _upsample_like(x, sizes[height - 1])
+
+        def side(x, h):
+            # side output saliency map (before sigmoid)
+            x = getattr(self, f"side{h}")(x)
+            x = _upsample_like(x, sizes[1])
+            maps.append(x)
+
+        def fuse():
+            # fuse saliency probability maps
+            maps.reverse()
+            x = torch.cat(maps, 1)
+            x = getattr(self, "outconv")(x)
+            maps.insert(0, x)
+            return [torch.sigmoid(x) for x in maps]
+
+        unet(x)
+        maps = fuse()
+        return maps
+
+    def _make_layers(self, cfgs):
+        self.height = int((len(cfgs) + 1) / 2)
+        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
+        for k, v in cfgs.items():
+            # build rsu block
+            self.add_module(k, RSU(v[0], *v[1]))
+            if v[2] > 0:
+                # build side layer
+                self.add_module(
+                    f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1)
+                )
+        # build fuse layer
+        self.add_module(
+            "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)
+        )
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    sizes = _size_map(x, self.height)
+    maps = []  # storage for maps
+
+    # side saliency map
+    def unet(x, height=1):
+        if height < 6:
+            x1 = getattr(self, f"stage{height}")(x)
+            x2 = unet(getattr(self, "downsample")(x1), height + 1)
+            x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
+            side(x, height)
+            return _upsample_like(x, sizes[height - 1]) if height > 1 else x
+        else:
+            x = getattr(self, f"stage{height}")(x)
+            side(x, height)
+            return _upsample_like(x, sizes[height - 1])
+
+    def side(x, h):
+        # side output saliency map (before sigmoid)
+        x = getattr(self, f"side{h}")(x)
+        x = _upsample_like(x, sizes[1])
+        maps.append(x)
+
+    def fuse():
+        # fuse saliency probability maps
+        maps.reverse()
+        x = torch.cat(maps, 1)
+        x = getattr(self, "outconv")(x)
+        maps.insert(0, x)
+        return [torch.sigmoid(x) for x in maps]
+
+    unet(x)
+    maps = fuse()
+    return maps
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/yolov4/index.html b/docs/api/carvekit/ml/arch/yolov4/index.html new file mode 100644 index 0000000..10cd438 --- /dev/null +++ b/docs/api/carvekit/ml/arch/yolov4/index.html @@ -0,0 +1,79 @@ + + + + + + +carvekit.ml.arch.yolov4 API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.yolov4

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch.yolov4.models
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4 +License: Apache License …

+
+
carvekit.ml.arch.yolov4.utils
+
+
+
+
carvekit.ml.arch.yolov4.yolo_layer
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4 +License: Apache License …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/yolov4/models.html b/docs/api/carvekit/ml/arch/yolov4/models.html new file mode 100644 index 0000000..5974ec3 --- /dev/null +++ b/docs/api/carvekit/ml/arch/yolov4/models.html @@ -0,0 +1,2193 @@ + + + + + + +carvekit.ml.arch.yolov4.models API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.yolov4.models

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4 +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4
+License: Apache License 2.0
+"""
+import torch
+from torch import nn
+import torch.nn.functional as F
+from carvekit.ml.arch.yolov4.yolo_layer import YoloLayer
+
+
+def get_region_boxes(boxes_and_confs):
+    # print('Getting boxes from boxes and confs ...')
+
+    boxes_list = []
+    confs_list = []
+
+    for item in boxes_and_confs:
+        boxes_list.append(item[0])
+        confs_list.append(item[1])
+
+    # boxes: [batch, num1 + num2 + num3, 1, 4]
+    # confs: [batch, num1 + num2 + num3, num_classes]
+    boxes = torch.cat(boxes_list, dim=1)
+    confs = torch.cat(confs_list, dim=1)
+
+    return [boxes, confs]
+
+
+class Mish(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x * (torch.tanh(torch.nn.functional.softplus(x)))
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self):
+        super(Upsample, self).__init__()
+
+    def forward(self, x, target_size, inference=False):
+        assert x.data.dim() == 4
+        # _, _, tH, tW = target_size
+
+        if inference:
+
+            # B = x.data.size(0)
+            # C = x.data.size(1)
+            # H = x.data.size(2)
+            # W = x.data.size(3)
+
+            return (
+                x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1)
+                .expand(
+                    x.size(0),
+                    x.size(1),
+                    x.size(2),
+                    target_size[2] // x.size(2),
+                    x.size(3),
+                    target_size[3] // x.size(3),
+                )
+                .contiguous()
+                .view(x.size(0), x.size(1), target_size[2], target_size[3])
+            )
+        else:
+            return F.interpolate(
+                x, size=(target_size[2], target_size[3]), mode="nearest"
+            )
+
+
+class Conv_Bn_Activation(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        activation,
+        bn=True,
+        bias=False,
+    ):
+        super().__init__()
+        pad = (kernel_size - 1) // 2
+
+        self.conv = nn.ModuleList()
+        if bias:
+            self.conv.append(
+                nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad)
+            )
+        else:
+            self.conv.append(
+                nn.Conv2d(
+                    in_channels, out_channels, kernel_size, stride, pad, bias=False
+                )
+            )
+        if bn:
+            self.conv.append(nn.BatchNorm2d(out_channels))
+        if activation == "mish":
+            self.conv.append(Mish())
+        elif activation == "relu":
+            self.conv.append(nn.ReLU(inplace=True))
+        elif activation == "leaky":
+            self.conv.append(nn.LeakyReLU(0.1, inplace=True))
+        elif activation == "linear":
+            pass
+        else:
+            raise Exception("activation error")
+
+    def forward(self, x):
+        for l in self.conv:
+            x = l(x)
+        return x
+
+
+class ResBlock(nn.Module):
+    """
+    Sequential residual blocks each of which consists of \
+    two convolution layers.
+    Args:
+        ch (int): number of input and output channels.
+        nblocks (int): number of residual blocks.
+        shortcut (bool): if True, residual tensor addition is enabled.
+    """
+
+    def __init__(self, ch, nblocks=1, shortcut=True):
+        super().__init__()
+        self.shortcut = shortcut
+        self.module_list = nn.ModuleList()
+        for i in range(nblocks):
+            resblock_one = nn.ModuleList()
+            resblock_one.append(Conv_Bn_Activation(ch, ch, 1, 1, "mish"))
+            resblock_one.append(Conv_Bn_Activation(ch, ch, 3, 1, "mish"))
+            self.module_list.append(resblock_one)
+
+    def forward(self, x):
+        for module in self.module_list:
+            h = x
+            for res in module:
+                h = res(h)
+            x = x + h if self.shortcut else h
+        return x
+
+
+class DownSample1(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(3, 32, 3, 1, "mish")
+
+        self.conv2 = Conv_Bn_Activation(32, 64, 3, 2, "mish")
+        self.conv3 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+        # [route]
+        # layers = -2
+        self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+
+        self.conv5 = Conv_Bn_Activation(64, 32, 1, 1, "mish")
+        self.conv6 = Conv_Bn_Activation(32, 64, 3, 1, "mish")
+        # [shortcut]
+        # from=-3
+        # activation = linear
+
+        self.conv7 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+        # [route]
+        # layers = -1, -7
+        self.conv8 = Conv_Bn_Activation(128, 64, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x2)
+        # route -2
+        x4 = self.conv4(x2)
+        x5 = self.conv5(x4)
+        x6 = self.conv6(x5)
+        # shortcut -3
+        x6 = x6 + x4
+
+        x7 = self.conv7(x6)
+        # [route]
+        # layers = -1, -7
+        x7 = torch.cat([x7, x3], dim=1)
+        x8 = self.conv8(x7)
+        return x8
+
+
+class DownSample2(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(64, 128, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(128, 64, 1, 1, "mish")
+        # r -2
+        self.conv3 = Conv_Bn_Activation(128, 64, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=64, nblocks=2)
+
+        # s -3
+        self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+        # r -1 -10
+        self.conv5 = Conv_Bn_Activation(128, 128, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+
+class DownSample3(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(128, 256, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(256, 128, 1, 1, "mish")
+        self.conv3 = Conv_Bn_Activation(256, 128, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=128, nblocks=8)
+        self.conv4 = Conv_Bn_Activation(128, 128, 1, 1, "mish")
+        self.conv5 = Conv_Bn_Activation(256, 256, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+
+class DownSample4(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(256, 512, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(512, 256, 1, 1, "mish")
+        self.conv3 = Conv_Bn_Activation(512, 256, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=256, nblocks=8)
+        self.conv4 = Conv_Bn_Activation(256, 256, 1, 1, "mish")
+        self.conv5 = Conv_Bn_Activation(512, 512, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+
+class DownSample5(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(512, 1024, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(1024, 512, 1, 1, "mish")
+        self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=512, nblocks=4)
+        self.conv4 = Conv_Bn_Activation(512, 512, 1, 1, "mish")
+        self.conv5 = Conv_Bn_Activation(1024, 1024, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+
+class Neck(nn.Module):
+    def __init__(self, inference=False):
+        super().__init__()
+        self.inference = inference
+
+        self.conv1 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv2 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        # SPP
+        self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9 // 2)
+        self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13 // 2)
+
+        # R -1 -3 -5 -6
+        # SPP
+        self.conv4 = Conv_Bn_Activation(2048, 512, 1, 1, "leaky")
+        self.conv5 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv6 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv7 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        # UP
+        self.upsample1 = Upsample()
+        # R 85
+        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        # R -1 -3
+        self.conv9 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv10 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv11 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv12 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv13 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv14 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        # UP
+        self.upsample2 = Upsample()
+        # R 54
+        self.conv15 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        # R -1 -3
+        self.conv16 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        self.conv17 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
+        self.conv18 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        self.conv19 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
+        self.conv20 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+
+    def forward(self, input, downsample4, downsample3, inference=False):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x2)
+        # SPP
+        m1 = self.maxpool1(x3)
+        m2 = self.maxpool2(x3)
+        m3 = self.maxpool3(x3)
+        spp = torch.cat([m3, m2, m1, x3], dim=1)
+        # SPP end
+        x4 = self.conv4(spp)
+        x5 = self.conv5(x4)
+        x6 = self.conv6(x5)
+        x7 = self.conv7(x6)
+        # UP
+        up = self.upsample1(x7, downsample4.size(), self.inference)
+        # R 85
+        x8 = self.conv8(downsample4)
+        # R -1 -3
+        x8 = torch.cat([x8, up], dim=1)
+
+        x9 = self.conv9(x8)
+        x10 = self.conv10(x9)
+        x11 = self.conv11(x10)
+        x12 = self.conv12(x11)
+        x13 = self.conv13(x12)
+        x14 = self.conv14(x13)
+
+        # UP
+        up = self.upsample2(x14, downsample3.size(), self.inference)
+        # R 54
+        x15 = self.conv15(downsample3)
+        # R -1 -3
+        x15 = torch.cat([x15, up], dim=1)
+
+        x16 = self.conv16(x15)
+        x17 = self.conv17(x16)
+        x18 = self.conv18(x17)
+        x19 = self.conv19(x18)
+        x20 = self.conv20(x19)
+        return x20, x13, x6
+
+
+class Yolov4Head(nn.Module):
+    def __init__(self, output_ch, n_classes, inference=False):
+        super().__init__()
+        self.inference = inference
+
+        self.conv1 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
+        self.conv2 = Conv_Bn_Activation(
+            256, output_ch, 1, 1, "linear", bn=False, bias=True
+        )
+
+        self.yolo1 = YoloLayer(
+            anchor_mask=[0, 1, 2],
+            num_classes=n_classes,
+            anchors=[
+                12,
+                16,
+                19,
+                36,
+                40,
+                28,
+                36,
+                75,
+                76,
+                55,
+                72,
+                146,
+                142,
+                110,
+                192,
+                243,
+                459,
+                401,
+            ],
+            num_anchors=9,
+            stride=8,
+        )
+
+        # R -4
+        self.conv3 = Conv_Bn_Activation(128, 256, 3, 2, "leaky")
+
+        # R -1 -16
+        self.conv4 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv5 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv6 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv7 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv9 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv10 = Conv_Bn_Activation(
+            512, output_ch, 1, 1, "linear", bn=False, bias=True
+        )
+
+        self.yolo2 = YoloLayer(
+            anchor_mask=[3, 4, 5],
+            num_classes=n_classes,
+            anchors=[
+                12,
+                16,
+                19,
+                36,
+                40,
+                28,
+                36,
+                75,
+                76,
+                55,
+                72,
+                146,
+                142,
+                110,
+                192,
+                243,
+                459,
+                401,
+            ],
+            num_anchors=9,
+            stride=16,
+        )
+
+        # R -4
+        self.conv11 = Conv_Bn_Activation(256, 512, 3, 2, "leaky")
+
+        # R -1 -37
+        self.conv12 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv13 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv14 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv15 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv16 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv17 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv18 = Conv_Bn_Activation(
+            1024, output_ch, 1, 1, "linear", bn=False, bias=True
+        )
+
+        self.yolo3 = YoloLayer(
+            anchor_mask=[6, 7, 8],
+            num_classes=n_classes,
+            anchors=[
+                12,
+                16,
+                19,
+                36,
+                40,
+                28,
+                36,
+                75,
+                76,
+                55,
+                72,
+                146,
+                142,
+                110,
+                192,
+                243,
+                459,
+                401,
+            ],
+            num_anchors=9,
+            stride=32,
+        )
+
+    def forward(self, input1, input2, input3):
+        x1 = self.conv1(input1)
+        x2 = self.conv2(x1)
+
+        x3 = self.conv3(input1)
+        # R -1 -16
+        x3 = torch.cat([x3, input2], dim=1)
+        x4 = self.conv4(x3)
+        x5 = self.conv5(x4)
+        x6 = self.conv6(x5)
+        x7 = self.conv7(x6)
+        x8 = self.conv8(x7)
+        x9 = self.conv9(x8)
+        x10 = self.conv10(x9)
+
+        # R -4
+        x11 = self.conv11(x8)
+        # R -1 -37
+        x11 = torch.cat([x11, input3], dim=1)
+
+        x12 = self.conv12(x11)
+        x13 = self.conv13(x12)
+        x14 = self.conv14(x13)
+        x15 = self.conv15(x14)
+        x16 = self.conv16(x15)
+        x17 = self.conv17(x16)
+        x18 = self.conv18(x17)
+
+        if self.inference:
+            y1 = self.yolo1(x2)
+            y2 = self.yolo2(x10)
+            y3 = self.yolo3(x18)
+
+            return get_region_boxes([y1, y2, y3])
+
+        else:
+            return [x2, x10, x18]
+
+
+class Yolov4(nn.Module):
+    def __init__(self, n_classes=80, inference=False):
+        super().__init__()
+
+        output_ch = (4 + 1 + n_classes) * 3
+
+        # backbone
+        self.down1 = DownSample1()
+        self.down2 = DownSample2()
+        self.down3 = DownSample3()
+        self.down4 = DownSample4()
+        self.down5 = DownSample5()
+        # neck
+        self.neek = Neck(inference)
+
+        # head
+        self.head = Yolov4Head(output_ch, n_classes, inference)
+
+    def forward(self, input):
+        d1 = self.down1(input)
+        d2 = self.down2(d1)
+        d3 = self.down3(d2)
+        d4 = self.down4(d3)
+        d5 = self.down5(d4)
+
+        x20, x13, x6 = self.neek(d5, d4, d3)
+
+        output = self.head(x20, x13, x6)
+        return output
+
+
+
+
+
+
+
+

Functions

+
+
+def get_region_boxes(boxes_and_confs) +
+
+
+
+ +Expand source code + +
def get_region_boxes(boxes_and_confs):
+    # print('Getting boxes from boxes and confs ...')
+
+    boxes_list = []
+    confs_list = []
+
+    for item in boxes_and_confs:
+        boxes_list.append(item[0])
+        confs_list.append(item[1])
+
+    # boxes: [batch, num1 + num2 + num3, 1, 4]
+    # confs: [batch, num1 + num2 + num3, num_classes]
+    boxes = torch.cat(boxes_list, dim=1)
+    confs = torch.cat(confs_list, dim=1)
+
+    return [boxes, confs]
+
+
+
+
+
+

Classes

+
+
+class Conv_Bn_Activation +(in_channels, out_channels, kernel_size, stride, activation, bn=True, bias=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Conv_Bn_Activation(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        activation,
+        bn=True,
+        bias=False,
+    ):
+        super().__init__()
+        pad = (kernel_size - 1) // 2
+
+        self.conv = nn.ModuleList()
+        if bias:
+            self.conv.append(
+                nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad)
+            )
+        else:
+            self.conv.append(
+                nn.Conv2d(
+                    in_channels, out_channels, kernel_size, stride, pad, bias=False
+                )
+            )
+        if bn:
+            self.conv.append(nn.BatchNorm2d(out_channels))
+        if activation == "mish":
+            self.conv.append(Mish())
+        elif activation == "relu":
+            self.conv.append(nn.ReLU(inplace=True))
+        elif activation == "leaky":
+            self.conv.append(nn.LeakyReLU(0.1, inplace=True))
+        elif activation == "linear":
+            pass
+        else:
+            raise Exception("activation error")
+
+    def forward(self, x):
+        for l in self.conv:
+            x = l(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    for l in self.conv:
+        x = l(x)
+    return x
+
+
+
+
+
+class DownSample1 +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DownSample1(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(3, 32, 3, 1, "mish")
+
+        self.conv2 = Conv_Bn_Activation(32, 64, 3, 2, "mish")
+        self.conv3 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+        # [route]
+        # layers = -2
+        self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+
+        self.conv5 = Conv_Bn_Activation(64, 32, 1, 1, "mish")
+        self.conv6 = Conv_Bn_Activation(32, 64, 3, 1, "mish")
+        # [shortcut]
+        # from=-3
+        # activation = linear
+
+        self.conv7 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+        # [route]
+        # layers = -1, -7
+        self.conv8 = Conv_Bn_Activation(128, 64, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x2)
+        # route -2
+        x4 = self.conv4(x2)
+        x5 = self.conv5(x4)
+        x6 = self.conv6(x5)
+        # shortcut -3
+        x6 = x6 + x4
+
+        x7 = self.conv7(x6)
+        # [route]
+        # layers = -1, -7
+        x7 = torch.cat([x7, x3], dim=1)
+        x8 = self.conv8(x7)
+        return x8
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input):
+    x1 = self.conv1(input)
+    x2 = self.conv2(x1)
+    x3 = self.conv3(x2)
+    # route -2
+    x4 = self.conv4(x2)
+    x5 = self.conv5(x4)
+    x6 = self.conv6(x5)
+    # shortcut -3
+    x6 = x6 + x4
+
+    x7 = self.conv7(x6)
+    # [route]
+    # layers = -1, -7
+    x7 = torch.cat([x7, x3], dim=1)
+    x8 = self.conv8(x7)
+    return x8
+
+
+
+
+
+class DownSample2 +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DownSample2(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(64, 128, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(128, 64, 1, 1, "mish")
+        # r -2
+        self.conv3 = Conv_Bn_Activation(128, 64, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=64, nblocks=2)
+
+        # s -3
+        self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, "mish")
+        # r -1 -10
+        self.conv5 = Conv_Bn_Activation(128, 128, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input):
+    x1 = self.conv1(input)
+    x2 = self.conv2(x1)
+    x3 = self.conv3(x1)
+
+    r = self.resblock(x3)
+    x4 = self.conv4(r)
+
+    x4 = torch.cat([x4, x2], dim=1)
+    x5 = self.conv5(x4)
+    return x5
+
+
+
+
+
+class DownSample3 +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DownSample3(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(128, 256, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(256, 128, 1, 1, "mish")
+        self.conv3 = Conv_Bn_Activation(256, 128, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=128, nblocks=8)
+        self.conv4 = Conv_Bn_Activation(128, 128, 1, 1, "mish")
+        self.conv5 = Conv_Bn_Activation(256, 256, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input):
+    x1 = self.conv1(input)
+    x2 = self.conv2(x1)
+    x3 = self.conv3(x1)
+
+    r = self.resblock(x3)
+    x4 = self.conv4(r)
+
+    x4 = torch.cat([x4, x2], dim=1)
+    x5 = self.conv5(x4)
+    return x5
+
+
+
+
+
+class DownSample4 +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DownSample4(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(256, 512, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(512, 256, 1, 1, "mish")
+        self.conv3 = Conv_Bn_Activation(512, 256, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=256, nblocks=8)
+        self.conv4 = Conv_Bn_Activation(256, 256, 1, 1, "mish")
+        self.conv5 = Conv_Bn_Activation(512, 512, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input):
+    x1 = self.conv1(input)
+    x2 = self.conv2(x1)
+    x3 = self.conv3(x1)
+
+    r = self.resblock(x3)
+    x4 = self.conv4(r)
+
+    x4 = torch.cat([x4, x2], dim=1)
+    x5 = self.conv5(x4)
+    return x5
+
+
+
+
+
+class DownSample5 +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DownSample5(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv_Bn_Activation(512, 1024, 3, 2, "mish")
+        self.conv2 = Conv_Bn_Activation(1024, 512, 1, 1, "mish")
+        self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, "mish")
+
+        self.resblock = ResBlock(ch=512, nblocks=4)
+        self.conv4 = Conv_Bn_Activation(512, 512, 1, 1, "mish")
+        self.conv5 = Conv_Bn_Activation(1024, 1024, 1, 1, "mish")
+
+    def forward(self, input):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x1)
+
+        r = self.resblock(x3)
+        x4 = self.conv4(r)
+
+        x4 = torch.cat([x4, x2], dim=1)
+        x5 = self.conv5(x4)
+        return x5
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input):
+    x1 = self.conv1(input)
+    x2 = self.conv2(x1)
+    x3 = self.conv3(x1)
+
+    r = self.resblock(x3)
+    x4 = self.conv4(r)
+
+    x4 = torch.cat([x4, x2], dim=1)
+    x5 = self.conv5(x4)
+    return x5
+
+
+
+
+
+class Mish +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Mish(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x * (torch.tanh(torch.nn.functional.softplus(x)))
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = x * (torch.tanh(torch.nn.functional.softplus(x)))
+    return x
+
+
+
+
+
+class Neck +(inference=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Neck(nn.Module):
+    def __init__(self, inference=False):
+        super().__init__()
+        self.inference = inference
+
+        self.conv1 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv2 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        # SPP
+        self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9 // 2)
+        self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13 // 2)
+
+        # R -1 -3 -5 -6
+        # SPP
+        self.conv4 = Conv_Bn_Activation(2048, 512, 1, 1, "leaky")
+        self.conv5 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv6 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv7 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        # UP
+        self.upsample1 = Upsample()
+        # R 85
+        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        # R -1 -3
+        self.conv9 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv10 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv11 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv12 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv13 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv14 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        # UP
+        self.upsample2 = Upsample()
+        # R 54
+        self.conv15 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        # R -1 -3
+        self.conv16 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        self.conv17 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
+        self.conv18 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+        self.conv19 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
+        self.conv20 = Conv_Bn_Activation(256, 128, 1, 1, "leaky")
+
+    def forward(self, input, downsample4, downsample3, inference=False):
+        x1 = self.conv1(input)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x2)
+        # SPP
+        m1 = self.maxpool1(x3)
+        m2 = self.maxpool2(x3)
+        m3 = self.maxpool3(x3)
+        spp = torch.cat([m3, m2, m1, x3], dim=1)
+        # SPP end
+        x4 = self.conv4(spp)
+        x5 = self.conv5(x4)
+        x6 = self.conv6(x5)
+        x7 = self.conv7(x6)
+        # UP
+        up = self.upsample1(x7, downsample4.size(), self.inference)
+        # R 85
+        x8 = self.conv8(downsample4)
+        # R -1 -3
+        x8 = torch.cat([x8, up], dim=1)
+
+        x9 = self.conv9(x8)
+        x10 = self.conv10(x9)
+        x11 = self.conv11(x10)
+        x12 = self.conv12(x11)
+        x13 = self.conv13(x12)
+        x14 = self.conv14(x13)
+
+        # UP
+        up = self.upsample2(x14, downsample3.size(), self.inference)
+        # R 54
+        x15 = self.conv15(downsample3)
+        # R -1 -3
+        x15 = torch.cat([x15, up], dim=1)
+
+        x16 = self.conv16(x15)
+        x17 = self.conv17(x16)
+        x18 = self.conv18(x17)
+        x19 = self.conv19(x18)
+        x20 = self.conv20(x19)
+        return x20, x13, x6
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input, downsample4, downsample3, inference=False) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input, downsample4, downsample3, inference=False):
+    x1 = self.conv1(input)
+    x2 = self.conv2(x1)
+    x3 = self.conv3(x2)
+    # SPP
+    m1 = self.maxpool1(x3)
+    m2 = self.maxpool2(x3)
+    m3 = self.maxpool3(x3)
+    spp = torch.cat([m3, m2, m1, x3], dim=1)
+    # SPP end
+    x4 = self.conv4(spp)
+    x5 = self.conv5(x4)
+    x6 = self.conv6(x5)
+    x7 = self.conv7(x6)
+    # UP
+    up = self.upsample1(x7, downsample4.size(), self.inference)
+    # R 85
+    x8 = self.conv8(downsample4)
+    # R -1 -3
+    x8 = torch.cat([x8, up], dim=1)
+
+    x9 = self.conv9(x8)
+    x10 = self.conv10(x9)
+    x11 = self.conv11(x10)
+    x12 = self.conv12(x11)
+    x13 = self.conv13(x12)
+    x14 = self.conv14(x13)
+
+    # UP
+    up = self.upsample2(x14, downsample3.size(), self.inference)
+    # R 54
+    x15 = self.conv15(downsample3)
+    # R -1 -3
+    x15 = torch.cat([x15, up], dim=1)
+
+    x16 = self.conv16(x15)
+    x17 = self.conv17(x16)
+    x18 = self.conv18(x17)
+    x19 = self.conv19(x18)
+    x20 = self.conv20(x19)
+    return x20, x13, x6
+
+
+
+
+
+class ResBlock +(ch, nblocks=1, shortcut=True) +
+
+

Sequential residual blocks each of which consists of +two convolution layers.

+

Args

+
+
ch : int
+
number of input and output channels.
+
nblocks : int
+
number of residual blocks.
+
shortcut : bool
+
if True, residual tensor addition is enabled.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResBlock(nn.Module):
+    """
+    Sequential residual blocks each of which consists of \
+    two convolution layers.
+    Args:
+        ch (int): number of input and output channels.
+        nblocks (int): number of residual blocks.
+        shortcut (bool): if True, residual tensor addition is enabled.
+    """
+
+    def __init__(self, ch, nblocks=1, shortcut=True):
+        super().__init__()
+        self.shortcut = shortcut
+        self.module_list = nn.ModuleList()
+        for i in range(nblocks):
+            resblock_one = nn.ModuleList()
+            resblock_one.append(Conv_Bn_Activation(ch, ch, 1, 1, "mish"))
+            resblock_one.append(Conv_Bn_Activation(ch, ch, 3, 1, "mish"))
+            self.module_list.append(resblock_one)
+
+    def forward(self, x):
+        for module in self.module_list:
+            h = x
+            for res in module:
+                h = res(h)
+            x = x + h if self.shortcut else h
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    for module in self.module_list:
+        h = x
+        for res in module:
+            h = res(h)
+        x = x + h if self.shortcut else h
+    return x
+
+
+
+
+
+class Upsample +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Upsample(nn.Module):
+    def __init__(self):
+        super(Upsample, self).__init__()
+
+    def forward(self, x, target_size, inference=False):
+        assert x.data.dim() == 4
+        # _, _, tH, tW = target_size
+
+        if inference:
+
+            # B = x.data.size(0)
+            # C = x.data.size(1)
+            # H = x.data.size(2)
+            # W = x.data.size(3)
+
+            return (
+                x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1)
+                .expand(
+                    x.size(0),
+                    x.size(1),
+                    x.size(2),
+                    target_size[2] // x.size(2),
+                    x.size(3),
+                    target_size[3] // x.size(3),
+                )
+                .contiguous()
+                .view(x.size(0), x.size(1), target_size[2], target_size[3])
+            )
+        else:
+            return F.interpolate(
+                x, size=(target_size[2], target_size[3]), mode="nearest"
+            )
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, x, target_size, inference=False) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, target_size, inference=False):
+    assert x.data.dim() == 4
+    # _, _, tH, tW = target_size
+
+    if inference:
+
+        # B = x.data.size(0)
+        # C = x.data.size(1)
+        # H = x.data.size(2)
+        # W = x.data.size(3)
+
+        return (
+            x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1)
+            .expand(
+                x.size(0),
+                x.size(1),
+                x.size(2),
+                target_size[2] // x.size(2),
+                x.size(3),
+                target_size[3] // x.size(3),
+            )
+            .contiguous()
+            .view(x.size(0), x.size(1), target_size[2], target_size[3])
+        )
+    else:
+        return F.interpolate(
+            x, size=(target_size[2], target_size[3]), mode="nearest"
+        )
+
+
+
+
+
+class Yolov4 +(n_classes=80, inference=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Yolov4(nn.Module):
+    def __init__(self, n_classes=80, inference=False):
+        super().__init__()
+
+        output_ch = (4 + 1 + n_classes) * 3
+
+        # backbone
+        self.down1 = DownSample1()
+        self.down2 = DownSample2()
+        self.down3 = DownSample3()
+        self.down4 = DownSample4()
+        self.down5 = DownSample5()
+        # neck
+        self.neek = Neck(inference)
+
+        # head
+        self.head = Yolov4Head(output_ch, n_classes, inference)
+
+    def forward(self, input):
+        d1 = self.down1(input)
+        d2 = self.down2(d1)
+        d3 = self.down3(d2)
+        d4 = self.down4(d3)
+        d5 = self.down5(d4)
+
+        x20, x13, x6 = self.neek(d5, d4, d3)
+
+        output = self.head(x20, x13, x6)
+        return output
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def forward(self, input) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input):
+    d1 = self.down1(input)
+    d2 = self.down2(d1)
+    d3 = self.down3(d2)
+    d4 = self.down4(d3)
+    d5 = self.down5(d4)
+
+    x20, x13, x6 = self.neek(d5, d4, d3)
+
+    output = self.head(x20, x13, x6)
+    return output
+
+
+
+
+
+class Yolov4Head +(output_ch, n_classes, inference=False) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class Yolov4Head(nn.Module):
+    def __init__(self, output_ch, n_classes, inference=False):
+        super().__init__()
+        self.inference = inference
+
+        self.conv1 = Conv_Bn_Activation(128, 256, 3, 1, "leaky")
+        self.conv2 = Conv_Bn_Activation(
+            256, output_ch, 1, 1, "linear", bn=False, bias=True
+        )
+
+        self.yolo1 = YoloLayer(
+            anchor_mask=[0, 1, 2],
+            num_classes=n_classes,
+            anchors=[
+                12,
+                16,
+                19,
+                36,
+                40,
+                28,
+                36,
+                75,
+                76,
+                55,
+                72,
+                146,
+                142,
+                110,
+                192,
+                243,
+                459,
+                401,
+            ],
+            num_anchors=9,
+            stride=8,
+        )
+
+        # R -4
+        self.conv3 = Conv_Bn_Activation(128, 256, 3, 2, "leaky")
+
+        # R -1 -16
+        self.conv4 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv5 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv6 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv7 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, "leaky")
+        self.conv9 = Conv_Bn_Activation(256, 512, 3, 1, "leaky")
+        self.conv10 = Conv_Bn_Activation(
+            512, output_ch, 1, 1, "linear", bn=False, bias=True
+        )
+
+        self.yolo2 = YoloLayer(
+            anchor_mask=[3, 4, 5],
+            num_classes=n_classes,
+            anchors=[
+                12,
+                16,
+                19,
+                36,
+                40,
+                28,
+                36,
+                75,
+                76,
+                55,
+                72,
+                146,
+                142,
+                110,
+                192,
+                243,
+                459,
+                401,
+            ],
+            num_anchors=9,
+            stride=16,
+        )
+
+        # R -4
+        self.conv11 = Conv_Bn_Activation(256, 512, 3, 2, "leaky")
+
+        # R -1 -37
+        self.conv12 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv13 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv14 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv15 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv16 = Conv_Bn_Activation(1024, 512, 1, 1, "leaky")
+        self.conv17 = Conv_Bn_Activation(512, 1024, 3, 1, "leaky")
+        self.conv18 = Conv_Bn_Activation(
+            1024, output_ch, 1, 1, "linear", bn=False, bias=True
+        )
+
+        self.yolo3 = YoloLayer(
+            anchor_mask=[6, 7, 8],
+            num_classes=n_classes,
+            anchors=[
+                12,
+                16,
+                19,
+                36,
+                40,
+                28,
+                36,
+                75,
+                76,
+                55,
+                72,
+                146,
+                142,
+                110,
+                192,
+                243,
+                459,
+                401,
+            ],
+            num_anchors=9,
+            stride=32,
+        )
+
+    def forward(self, input1, input2, input3):
+        x1 = self.conv1(input1)
+        x2 = self.conv2(x1)
+
+        x3 = self.conv3(input1)
+        # R -1 -16
+        x3 = torch.cat([x3, input2], dim=1)
+        x4 = self.conv4(x3)
+        x5 = self.conv5(x4)
+        x6 = self.conv6(x5)
+        x7 = self.conv7(x6)
+        x8 = self.conv8(x7)
+        x9 = self.conv9(x8)
+        x10 = self.conv10(x9)
+
+        # R -4
+        x11 = self.conv11(x8)
+        # R -1 -37
+        x11 = torch.cat([x11, input3], dim=1)
+
+        x12 = self.conv12(x11)
+        x13 = self.conv13(x12)
+        x14 = self.conv14(x13)
+        x15 = self.conv15(x14)
+        x16 = self.conv16(x15)
+        x17 = self.conv17(x16)
+        x18 = self.conv18(x17)
+
+        if self.inference:
+            y1 = self.yolo1(x2)
+            y2 = self.yolo2(x10)
+            y3 = self.yolo3(x18)
+
+            return get_region_boxes([y1, y2, y3])
+
+        else:
+            return [x2, x10, x18]
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, input1, input2, input3) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, input1, input2, input3):
+    x1 = self.conv1(input1)
+    x2 = self.conv2(x1)
+
+    x3 = self.conv3(input1)
+    # R -1 -16
+    x3 = torch.cat([x3, input2], dim=1)
+    x4 = self.conv4(x3)
+    x5 = self.conv5(x4)
+    x6 = self.conv6(x5)
+    x7 = self.conv7(x6)
+    x8 = self.conv8(x7)
+    x9 = self.conv9(x8)
+    x10 = self.conv10(x9)
+
+    # R -4
+    x11 = self.conv11(x8)
+    # R -1 -37
+    x11 = torch.cat([x11, input3], dim=1)
+
+    x12 = self.conv12(x11)
+    x13 = self.conv13(x12)
+    x14 = self.conv14(x13)
+    x15 = self.conv15(x14)
+    x16 = self.conv16(x15)
+    x17 = self.conv17(x16)
+    x18 = self.conv18(x17)
+
+    if self.inference:
+        y1 = self.yolo1(x2)
+        y2 = self.yolo2(x10)
+        y3 = self.yolo3(x18)
+
+        return get_region_boxes([y1, y2, y3])
+
+    else:
+        return [x2, x10, x18]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/yolov4/utils.html b/docs/api/carvekit/ml/arch/yolov4/utils.html new file mode 100644 index 0000000..0f226f6 --- /dev/null +++ b/docs/api/carvekit/ml/arch/yolov4/utils.html @@ -0,0 +1,294 @@ + + + + + + +carvekit.ml.arch.yolov4.utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.yolov4.utils

+
+
+
+ +Expand source code + +
import numpy as np
+
+
+def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):
+    # print(boxes.shape)
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = confs.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        idx_self = order[0]
+        idx_other = order[1:]
+
+        keep.append(idx_self)
+
+        xx1 = np.maximum(x1[idx_self], x1[idx_other])
+        yy1 = np.maximum(y1[idx_self], y1[idx_other])
+        xx2 = np.minimum(x2[idx_self], x2[idx_other])
+        yy2 = np.minimum(y2[idx_self], y2[idx_other])
+
+        w = np.maximum(0.0, xx2 - xx1)
+        h = np.maximum(0.0, yy2 - yy1)
+        inter = w * h
+
+        if min_mode:
+            over = inter / np.minimum(areas[order[0]], areas[order[1:]])
+        else:
+            over = inter / (areas[order[0]] + areas[order[1:]] - inter)
+
+        inds = np.where(over <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return np.array(keep)
+
+
+def post_processing(conf_thresh, nms_thresh, output):
+    # anchors = [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]
+    # num_anchors = 9
+    # anchor_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+    # strides = [8, 16, 32]
+    # anchor_step = len(anchors) // num_anchors
+
+    # [batch, num, 1, 4]
+    box_array = output[0]
+    # [batch, num, num_classes]
+    confs = output[1]
+
+    if type(box_array).__name__ != "ndarray":
+        box_array = box_array.cpu().detach().numpy()
+        confs = confs.cpu().detach().numpy()
+
+    num_classes = confs.shape[2]
+
+    # [batch, num, 4]
+    box_array = box_array[:, :, 0]
+
+    # [batch, num, num_classes] --> [batch, num]
+    max_conf = np.max(confs, axis=2)
+    max_id = np.argmax(confs, axis=2)
+
+    bboxes_batch = []
+    for i in range(box_array.shape[0]):
+
+        argwhere = max_conf[i] > conf_thresh
+        l_box_array = box_array[i, argwhere, :]
+        l_max_conf = max_conf[i, argwhere]
+        l_max_id = max_id[i, argwhere]
+
+        bboxes = []
+        # nms for each class
+        for j in range(num_classes):
+
+            cls_argwhere = l_max_id == j
+            ll_box_array = l_box_array[cls_argwhere, :]
+            ll_max_conf = l_max_conf[cls_argwhere]
+            ll_max_id = l_max_id[cls_argwhere]
+
+            keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
+
+            if keep.size > 0:
+                ll_box_array = ll_box_array[keep, :]
+                ll_max_conf = ll_max_conf[keep]
+                ll_max_id = ll_max_id[keep]
+
+                for k in range(ll_box_array.shape[0]):
+                    bboxes.append(
+                        [
+                            ll_box_array[k, 0],
+                            ll_box_array[k, 1],
+                            ll_box_array[k, 2],
+                            ll_box_array[k, 3],
+                            ll_max_conf[k],
+                            ll_max_conf[k],
+                            ll_max_id[k],
+                        ]
+                    )
+
+        bboxes_batch.append(bboxes)
+
+    return bboxes_batch
+
+
+
+
+
+
+
+

Functions

+
+
+def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False) +
+
+
+
+ +Expand source code + +
def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):
+    # print(boxes.shape)
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = confs.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        idx_self = order[0]
+        idx_other = order[1:]
+
+        keep.append(idx_self)
+
+        xx1 = np.maximum(x1[idx_self], x1[idx_other])
+        yy1 = np.maximum(y1[idx_self], y1[idx_other])
+        xx2 = np.minimum(x2[idx_self], x2[idx_other])
+        yy2 = np.minimum(y2[idx_self], y2[idx_other])
+
+        w = np.maximum(0.0, xx2 - xx1)
+        h = np.maximum(0.0, yy2 - yy1)
+        inter = w * h
+
+        if min_mode:
+            over = inter / np.minimum(areas[order[0]], areas[order[1:]])
+        else:
+            over = inter / (areas[order[0]] + areas[order[1:]] - inter)
+
+        inds = np.where(over <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return np.array(keep)
+
+
+
+def post_processing(conf_thresh, nms_thresh, output) +
+
+
+
+ +Expand source code + +
def post_processing(conf_thresh, nms_thresh, output):
+    # anchors = [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]
+    # num_anchors = 9
+    # anchor_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+    # strides = [8, 16, 32]
+    # anchor_step = len(anchors) // num_anchors
+
+    # [batch, num, 1, 4]
+    box_array = output[0]
+    # [batch, num, num_classes]
+    confs = output[1]
+
+    if type(box_array).__name__ != "ndarray":
+        box_array = box_array.cpu().detach().numpy()
+        confs = confs.cpu().detach().numpy()
+
+    num_classes = confs.shape[2]
+
+    # [batch, num, 4]
+    box_array = box_array[:, :, 0]
+
+    # [batch, num, num_classes] --> [batch, num]
+    max_conf = np.max(confs, axis=2)
+    max_id = np.argmax(confs, axis=2)
+
+    bboxes_batch = []
+    for i in range(box_array.shape[0]):
+
+        argwhere = max_conf[i] > conf_thresh
+        l_box_array = box_array[i, argwhere, :]
+        l_max_conf = max_conf[i, argwhere]
+        l_max_id = max_id[i, argwhere]
+
+        bboxes = []
+        # nms for each class
+        for j in range(num_classes):
+
+            cls_argwhere = l_max_id == j
+            ll_box_array = l_box_array[cls_argwhere, :]
+            ll_max_conf = l_max_conf[cls_argwhere]
+            ll_max_id = l_max_id[cls_argwhere]
+
+            keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
+
+            if keep.size > 0:
+                ll_box_array = ll_box_array[keep, :]
+                ll_max_conf = ll_max_conf[keep]
+                ll_max_id = ll_max_id[keep]
+
+                for k in range(ll_box_array.shape[0]):
+                    bboxes.append(
+                        [
+                            ll_box_array[k, 0],
+                            ll_box_array[k, 1],
+                            ll_box_array[k, 2],
+                            ll_box_array[k, 3],
+                            ll_max_conf[k],
+                            ll_max_conf[k],
+                            ll_max_id[k],
+                        ]
+                    )
+
+        bboxes_batch.append(bboxes)
+
+    return bboxes_batch
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/arch/yolov4/yolo_layer.html b/docs/api/carvekit/ml/arch/yolov4/yolo_layer.html new file mode 100644 index 0000000..1d4eb26 --- /dev/null +++ b/docs/api/carvekit/ml/arch/yolov4/yolo_layer.html @@ -0,0 +1,984 @@ + + + + + + +carvekit.ml.arch.yolov4.yolo_layer API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.arch.yolov4.yolo_layer

+
+
+

Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4 +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/Tianxiaomo/pytorch-YOLOv4
+License: Apache License 2.0
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def yolo_forward(
+    output,
+    conf_thresh,
+    num_classes,
+    anchors,
+    num_anchors,
+    scale_x_y,
+    only_objectness=1,
+    validation=False,
+):
+    # Output would be invalid if it does not satisfy this assert
+    # assert (output.size(1) == (5 + num_classes) * num_anchors)
+
+    # print(output.size())
+
+    # Slice the second dimension (channel) of output into:
+    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
+    # And then into
+    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
+    batch = output.size(0)
+    H = output.size(2)
+    W = output.size(3)
+
+    bxy_list = []
+    bwh_list = []
+    det_confs_list = []
+    cls_confs_list = []
+
+    for i in range(num_anchors):
+        begin = i * (5 + num_classes)
+        end = (i + 1) * (5 + num_classes)
+
+        bxy_list.append(output[:, begin : begin + 2])
+        bwh_list.append(output[:, begin + 2 : begin + 4])
+        det_confs_list.append(output[:, begin + 4 : begin + 5])
+        cls_confs_list.append(output[:, begin + 5 : end])
+
+    # Shape: [batch, num_anchors * 2, H, W]
+    bxy = torch.cat(bxy_list, dim=1)
+    # Shape: [batch, num_anchors * 2, H, W]
+    bwh = torch.cat(bwh_list, dim=1)
+
+    # Shape: [batch, num_anchors, H, W]
+    det_confs = torch.cat(det_confs_list, dim=1)
+    # Shape: [batch, num_anchors * H * W]
+    det_confs = det_confs.view(batch, num_anchors * H * W)
+
+    # Shape: [batch, num_anchors * num_classes, H, W]
+    cls_confs = torch.cat(cls_confs_list, dim=1)
+    # Shape: [batch, num_anchors, num_classes, H * W]
+    cls_confs = cls_confs.view(batch, num_anchors, num_classes, H * W)
+    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
+    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(
+        batch, num_anchors * H * W, num_classes
+    )
+
+    # Apply sigmoid(), exp() and softmax() to slices
+    #
+    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
+    bwh = torch.exp(bwh)
+    det_confs = torch.sigmoid(det_confs)
+    cls_confs = torch.sigmoid(cls_confs)
+
+    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
+    grid_x = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0
+        ),
+        axis=0,
+    )
+    grid_y = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1), axis=0
+        ),
+        axis=0,
+    )
+    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
+    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
+
+    anchor_w = []
+    anchor_h = []
+    for i in range(num_anchors):
+        anchor_w.append(anchors[i * 2])
+        anchor_h.append(anchors[i * 2 + 1])
+
+    device = None
+    cuda_check = output.is_cuda
+    if cuda_check:
+        device = output.get_device()
+
+    bx_list = []
+    by_list = []
+    bw_list = []
+    bh_list = []
+
+    # Apply C-x, C-y, P-w, P-h
+    for i in range(num_anchors):
+        ii = i * 2
+        # Shape: [batch, 1, H, W]
+        bx = bxy[:, ii : ii + 1] + torch.tensor(
+            grid_x, device=device, dtype=torch.float32
+        )  # grid_x.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(
+            grid_y, device=device, dtype=torch.float32
+        )  # grid_y.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        bw = bwh[:, ii : ii + 1] * anchor_w[i]
+        # Shape: [batch, 1, H, W]
+        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
+
+        bx_list.append(bx)
+        by_list.append(by)
+        bw_list.append(bw)
+        bh_list.append(bh)
+
+    ########################################
+    #   Figure out bboxes from slices     #
+    ########################################
+
+    # Shape: [batch, num_anchors, H, W]
+    bx = torch.cat(bx_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    by = torch.cat(by_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bw = torch.cat(bw_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bh = torch.cat(bh_list, dim=1)
+
+    # Shape: [batch, 2 * num_anchors, H, W]
+    bx_bw = torch.cat((bx, bw), dim=1)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    by_bh = torch.cat((by, bh), dim=1)
+
+    # normalize coordinates to [0, 1]
+    bx_bw /= W
+    by_bh /= H
+
+    # Shape: [batch, num_anchors * H * W, 1]
+    bx = bx_bw[:, :num_anchors].view(batch, num_anchors * H * W, 1)
+    by = by_bh[:, :num_anchors].view(batch, num_anchors * H * W, 1)
+    bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+    bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(
+        batch, num_anchors * H * W, 1, 4
+    )
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
+
+    # boxes:     [batch, num_anchors * H * W, 1, 4]
+    # cls_confs: [batch, num_anchors * H * W, num_classes]
+    # det_confs: [batch, num_anchors * H * W]
+
+    det_confs = det_confs.view(batch, num_anchors * H * W, 1)
+    confs = cls_confs * det_confs
+
+    # boxes: [batch, num_anchors * H * W, 1, 4]
+    # confs: [batch, num_anchors * H * W, num_classes]
+
+    return boxes, confs
+
+
+def yolo_forward_dynamic(
+    output,
+    conf_thresh,
+    num_classes,
+    anchors,
+    num_anchors,
+    scale_x_y,
+    only_objectness=1,
+    validation=False,
+):
+    # Output would be invalid if it does not satisfy this assert
+    # assert (output.size(1) == (5 + num_classes) * num_anchors)
+
+    # print(output.size())
+
+    # Slice the second dimension (channel) of output into:
+    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
+    # And then into
+    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
+    # batch = output.size(0)
+    # H = output.size(2)
+    # W = output.size(3)
+
+    bxy_list = []
+    bwh_list = []
+    det_confs_list = []
+    cls_confs_list = []
+
+    for i in range(num_anchors):
+        begin = i * (5 + num_classes)
+        end = (i + 1) * (5 + num_classes)
+
+        bxy_list.append(output[:, begin : begin + 2])
+        bwh_list.append(output[:, begin + 2 : begin + 4])
+        det_confs_list.append(output[:, begin + 4 : begin + 5])
+        cls_confs_list.append(output[:, begin + 5 : end])
+
+    # Shape: [batch, num_anchors * 2, H, W]
+    bxy = torch.cat(bxy_list, dim=1)
+    # Shape: [batch, num_anchors * 2, H, W]
+    bwh = torch.cat(bwh_list, dim=1)
+
+    # Shape: [batch, num_anchors, H, W]
+    det_confs = torch.cat(det_confs_list, dim=1)
+    # Shape: [batch, num_anchors * H * W]
+    det_confs = det_confs.view(
+        output.size(0), num_anchors * output.size(2) * output.size(3)
+    )
+
+    # Shape: [batch, num_anchors * num_classes, H, W]
+    cls_confs = torch.cat(cls_confs_list, dim=1)
+    # Shape: [batch, num_anchors, num_classes, H * W]
+    cls_confs = cls_confs.view(
+        output.size(0), num_anchors, num_classes, output.size(2) * output.size(3)
+    )
+    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
+    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(
+        output.size(0), num_anchors * output.size(2) * output.size(3), num_classes
+    )
+
+    # Apply sigmoid(), exp() and softmax() to slices
+    #
+    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
+    bwh = torch.exp(bwh)
+    det_confs = torch.sigmoid(det_confs)
+    cls_confs = torch.sigmoid(cls_confs)
+
+    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
+    grid_x = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(
+                np.linspace(0, output.size(3) - 1, output.size(3)), axis=0
+            ).repeat(output.size(2), 0),
+            axis=0,
+        ),
+        axis=0,
+    )
+    grid_y = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(
+                np.linspace(0, output.size(2) - 1, output.size(2)), axis=1
+            ).repeat(output.size(3), 1),
+            axis=0,
+        ),
+        axis=0,
+    )
+    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
+    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
+
+    anchor_w = []
+    anchor_h = []
+    for i in range(num_anchors):
+        anchor_w.append(anchors[i * 2])
+        anchor_h.append(anchors[i * 2 + 1])
+
+    device = None
+    cuda_check = output.is_cuda
+    if cuda_check:
+        device = output.get_device()
+
+    bx_list = []
+    by_list = []
+    bw_list = []
+    bh_list = []
+
+    # Apply C-x, C-y, P-w, P-h
+    for i in range(num_anchors):
+        ii = i * 2
+        # Shape: [batch, 1, H, W]
+        bx = bxy[:, ii : ii + 1] + torch.tensor(
+            grid_x, device=device, dtype=torch.float32
+        )  # grid_x.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(
+            grid_y, device=device, dtype=torch.float32
+        )  # grid_y.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        bw = bwh[:, ii : ii + 1] * anchor_w[i]
+        # Shape: [batch, 1, H, W]
+        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
+
+        bx_list.append(bx)
+        by_list.append(by)
+        bw_list.append(bw)
+        bh_list.append(bh)
+
+    ########################################
+    #   Figure out bboxes from slices     #
+    ########################################
+
+    # Shape: [batch, num_anchors, H, W]
+    bx = torch.cat(bx_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    by = torch.cat(by_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bw = torch.cat(bw_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bh = torch.cat(bh_list, dim=1)
+
+    # Shape: [batch, 2 * num_anchors, H, W]
+    bx_bw = torch.cat((bx, bw), dim=1)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    by_bh = torch.cat((by, bh), dim=1)
+
+    # normalize coordinates to [0, 1]
+    bx_bw /= output.size(3)
+    by_bh /= output.size(2)
+
+    # Shape: [batch, num_anchors * H * W, 1]
+    bx = bx_bw[:, :num_anchors].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    by = by_bh[:, :num_anchors].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    bw = bx_bw[:, num_anchors:].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    bh = by_bh[:, num_anchors:].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4
+    )
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
+
+    # boxes:     [batch, num_anchors * H * W, 1, 4]
+    # cls_confs: [batch, num_anchors * H * W, num_classes]
+    # det_confs: [batch, num_anchors * H * W]
+
+    det_confs = det_confs.view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    confs = cls_confs * det_confs
+
+    # boxes: [batch, num_anchors * H * W, 1, 4]
+    # confs: [batch, num_anchors * H * W, num_classes]
+
+    return boxes, confs
+
+
+class YoloLayer(nn.Module):
+    """Yolo layer
+    model_out: while inference,is post-processing inside or outside the model
+        true:outside
+    """
+
+    def __init__(
+        self,
+        anchor_mask=[],
+        num_classes=0,
+        anchors=[],
+        num_anchors=1,
+        stride=32,
+        model_out=False,
+    ):
+        super(YoloLayer, self).__init__()
+        self.anchor_mask = anchor_mask
+        self.num_classes = num_classes
+        self.anchors = anchors
+        self.num_anchors = num_anchors
+        self.anchor_step = len(anchors) // num_anchors
+        self.coord_scale = 1
+        self.noobject_scale = 1
+        self.object_scale = 5
+        self.class_scale = 1
+        self.thresh = 0.6
+        self.stride = stride
+        self.seen = 0
+        self.scale_x_y = 1
+
+        self.model_out = model_out
+
+    def forward(self, output, target=None):
+        if self.training:
+            return output
+        masked_anchors = []
+        for m in self.anchor_mask:
+            masked_anchors += self.anchors[
+                m * self.anchor_step : (m + 1) * self.anchor_step
+            ]
+        masked_anchors = [anchor / self.stride for anchor in masked_anchors]
+
+        return yolo_forward_dynamic(
+            output,
+            self.thresh,
+            self.num_classes,
+            masked_anchors,
+            len(self.anchor_mask),
+            scale_x_y=self.scale_x_y,
+        )
+
+
+
+
+
+
+
+

Functions

+
+
+def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1, validation=False) +
+
+
+
+ +Expand source code + +
def yolo_forward(
+    output,
+    conf_thresh,
+    num_classes,
+    anchors,
+    num_anchors,
+    scale_x_y,
+    only_objectness=1,
+    validation=False,
+):
+    # Output would be invalid if it does not satisfy this assert
+    # assert (output.size(1) == (5 + num_classes) * num_anchors)
+
+    # print(output.size())
+
+    # Slice the second dimension (channel) of output into:
+    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
+    # And then into
+    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
+    batch = output.size(0)
+    H = output.size(2)
+    W = output.size(3)
+
+    bxy_list = []
+    bwh_list = []
+    det_confs_list = []
+    cls_confs_list = []
+
+    for i in range(num_anchors):
+        begin = i * (5 + num_classes)
+        end = (i + 1) * (5 + num_classes)
+
+        bxy_list.append(output[:, begin : begin + 2])
+        bwh_list.append(output[:, begin + 2 : begin + 4])
+        det_confs_list.append(output[:, begin + 4 : begin + 5])
+        cls_confs_list.append(output[:, begin + 5 : end])
+
+    # Shape: [batch, num_anchors * 2, H, W]
+    bxy = torch.cat(bxy_list, dim=1)
+    # Shape: [batch, num_anchors * 2, H, W]
+    bwh = torch.cat(bwh_list, dim=1)
+
+    # Shape: [batch, num_anchors, H, W]
+    det_confs = torch.cat(det_confs_list, dim=1)
+    # Shape: [batch, num_anchors * H * W]
+    det_confs = det_confs.view(batch, num_anchors * H * W)
+
+    # Shape: [batch, num_anchors * num_classes, H, W]
+    cls_confs = torch.cat(cls_confs_list, dim=1)
+    # Shape: [batch, num_anchors, num_classes, H * W]
+    cls_confs = cls_confs.view(batch, num_anchors, num_classes, H * W)
+    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
+    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(
+        batch, num_anchors * H * W, num_classes
+    )
+
+    # Apply sigmoid(), exp() and softmax() to slices
+    #
+    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
+    bwh = torch.exp(bwh)
+    det_confs = torch.sigmoid(det_confs)
+    cls_confs = torch.sigmoid(cls_confs)
+
+    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
+    grid_x = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0
+        ),
+        axis=0,
+    )
+    grid_y = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1), axis=0
+        ),
+        axis=0,
+    )
+    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
+    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
+
+    anchor_w = []
+    anchor_h = []
+    for i in range(num_anchors):
+        anchor_w.append(anchors[i * 2])
+        anchor_h.append(anchors[i * 2 + 1])
+
+    device = None
+    cuda_check = output.is_cuda
+    if cuda_check:
+        device = output.get_device()
+
+    bx_list = []
+    by_list = []
+    bw_list = []
+    bh_list = []
+
+    # Apply C-x, C-y, P-w, P-h
+    for i in range(num_anchors):
+        ii = i * 2
+        # Shape: [batch, 1, H, W]
+        bx = bxy[:, ii : ii + 1] + torch.tensor(
+            grid_x, device=device, dtype=torch.float32
+        )  # grid_x.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(
+            grid_y, device=device, dtype=torch.float32
+        )  # grid_y.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        bw = bwh[:, ii : ii + 1] * anchor_w[i]
+        # Shape: [batch, 1, H, W]
+        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
+
+        bx_list.append(bx)
+        by_list.append(by)
+        bw_list.append(bw)
+        bh_list.append(bh)
+
+    ########################################
+    #   Figure out bboxes from slices     #
+    ########################################
+
+    # Shape: [batch, num_anchors, H, W]
+    bx = torch.cat(bx_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    by = torch.cat(by_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bw = torch.cat(bw_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bh = torch.cat(bh_list, dim=1)
+
+    # Shape: [batch, 2 * num_anchors, H, W]
+    bx_bw = torch.cat((bx, bw), dim=1)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    by_bh = torch.cat((by, bh), dim=1)
+
+    # normalize coordinates to [0, 1]
+    bx_bw /= W
+    by_bh /= H
+
+    # Shape: [batch, num_anchors * H * W, 1]
+    bx = bx_bw[:, :num_anchors].view(batch, num_anchors * H * W, 1)
+    by = by_bh[:, :num_anchors].view(batch, num_anchors * H * W, 1)
+    bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+    bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(
+        batch, num_anchors * H * W, 1, 4
+    )
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
+
+    # boxes:     [batch, num_anchors * H * W, 1, 4]
+    # cls_confs: [batch, num_anchors * H * W, num_classes]
+    # det_confs: [batch, num_anchors * H * W]
+
+    det_confs = det_confs.view(batch, num_anchors * H * W, 1)
+    confs = cls_confs * det_confs
+
+    # boxes: [batch, num_anchors * H * W, 1, 4]
+    # confs: [batch, num_anchors * H * W, num_classes]
+
+    return boxes, confs
+
+
+
+def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1, validation=False) +
+
+
+
+ +Expand source code + +
def yolo_forward_dynamic(
+    output,
+    conf_thresh,
+    num_classes,
+    anchors,
+    num_anchors,
+    scale_x_y,
+    only_objectness=1,
+    validation=False,
+):
+    # Output would be invalid if it does not satisfy this assert
+    # assert (output.size(1) == (5 + num_classes) * num_anchors)
+
+    # print(output.size())
+
+    # Slice the second dimension (channel) of output into:
+    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
+    # And then into
+    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
+    # batch = output.size(0)
+    # H = output.size(2)
+    # W = output.size(3)
+
+    bxy_list = []
+    bwh_list = []
+    det_confs_list = []
+    cls_confs_list = []
+
+    for i in range(num_anchors):
+        begin = i * (5 + num_classes)
+        end = (i + 1) * (5 + num_classes)
+
+        bxy_list.append(output[:, begin : begin + 2])
+        bwh_list.append(output[:, begin + 2 : begin + 4])
+        det_confs_list.append(output[:, begin + 4 : begin + 5])
+        cls_confs_list.append(output[:, begin + 5 : end])
+
+    # Shape: [batch, num_anchors * 2, H, W]
+    bxy = torch.cat(bxy_list, dim=1)
+    # Shape: [batch, num_anchors * 2, H, W]
+    bwh = torch.cat(bwh_list, dim=1)
+
+    # Shape: [batch, num_anchors, H, W]
+    det_confs = torch.cat(det_confs_list, dim=1)
+    # Shape: [batch, num_anchors * H * W]
+    det_confs = det_confs.view(
+        output.size(0), num_anchors * output.size(2) * output.size(3)
+    )
+
+    # Shape: [batch, num_anchors * num_classes, H, W]
+    cls_confs = torch.cat(cls_confs_list, dim=1)
+    # Shape: [batch, num_anchors, num_classes, H * W]
+    cls_confs = cls_confs.view(
+        output.size(0), num_anchors, num_classes, output.size(2) * output.size(3)
+    )
+    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
+    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(
+        output.size(0), num_anchors * output.size(2) * output.size(3), num_classes
+    )
+
+    # Apply sigmoid(), exp() and softmax() to slices
+    #
+    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
+    bwh = torch.exp(bwh)
+    det_confs = torch.sigmoid(det_confs)
+    cls_confs = torch.sigmoid(cls_confs)
+
+    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
+    grid_x = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(
+                np.linspace(0, output.size(3) - 1, output.size(3)), axis=0
+            ).repeat(output.size(2), 0),
+            axis=0,
+        ),
+        axis=0,
+    )
+    grid_y = np.expand_dims(
+        np.expand_dims(
+            np.expand_dims(
+                np.linspace(0, output.size(2) - 1, output.size(2)), axis=1
+            ).repeat(output.size(3), 1),
+            axis=0,
+        ),
+        axis=0,
+    )
+    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
+    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
+
+    anchor_w = []
+    anchor_h = []
+    for i in range(num_anchors):
+        anchor_w.append(anchors[i * 2])
+        anchor_h.append(anchors[i * 2 + 1])
+
+    device = None
+    cuda_check = output.is_cuda
+    if cuda_check:
+        device = output.get_device()
+
+    bx_list = []
+    by_list = []
+    bw_list = []
+    bh_list = []
+
+    # Apply C-x, C-y, P-w, P-h
+    for i in range(num_anchors):
+        ii = i * 2
+        # Shape: [batch, 1, H, W]
+        bx = bxy[:, ii : ii + 1] + torch.tensor(
+            grid_x, device=device, dtype=torch.float32
+        )  # grid_x.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(
+            grid_y, device=device, dtype=torch.float32
+        )  # grid_y.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        bw = bwh[:, ii : ii + 1] * anchor_w[i]
+        # Shape: [batch, 1, H, W]
+        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
+
+        bx_list.append(bx)
+        by_list.append(by)
+        bw_list.append(bw)
+        bh_list.append(bh)
+
+    ########################################
+    #   Figure out bboxes from slices     #
+    ########################################
+
+    # Shape: [batch, num_anchors, H, W]
+    bx = torch.cat(bx_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    by = torch.cat(by_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bw = torch.cat(bw_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bh = torch.cat(bh_list, dim=1)
+
+    # Shape: [batch, 2 * num_anchors, H, W]
+    bx_bw = torch.cat((bx, bw), dim=1)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    by_bh = torch.cat((by, bh), dim=1)
+
+    # normalize coordinates to [0, 1]
+    bx_bw /= output.size(3)
+    by_bh /= output.size(2)
+
+    # Shape: [batch, num_anchors * H * W, 1]
+    bx = bx_bw[:, :num_anchors].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    by = by_bh[:, :num_anchors].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    bw = bx_bw[:, num_anchors:].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    bh = by_bh[:, num_anchors:].view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4
+    )
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
+
+    # boxes:     [batch, num_anchors * H * W, 1, 4]
+    # cls_confs: [batch, num_anchors * H * W, num_classes]
+    # det_confs: [batch, num_anchors * H * W]
+
+    det_confs = det_confs.view(
+        output.size(0), num_anchors * output.size(2) * output.size(3), 1
+    )
+    confs = cls_confs * det_confs
+
+    # boxes: [batch, num_anchors * H * W, 1, 4]
+    # confs: [batch, num_anchors * H * W, num_classes]
+
+    return boxes, confs
+
+
+
+
+
+

Classes

+
+
+class YoloLayer +(anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False) +
+
+

Yolo layer +model_out: while inference,is post-processing inside or outside the model +true:outside

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class YoloLayer(nn.Module):
+    """Yolo layer
+    model_out: while inference,is post-processing inside or outside the model
+        true:outside
+    """
+
+    def __init__(
+        self,
+        anchor_mask=[],
+        num_classes=0,
+        anchors=[],
+        num_anchors=1,
+        stride=32,
+        model_out=False,
+    ):
+        super(YoloLayer, self).__init__()
+        self.anchor_mask = anchor_mask
+        self.num_classes = num_classes
+        self.anchors = anchors
+        self.num_anchors = num_anchors
+        self.anchor_step = len(anchors) // num_anchors
+        self.coord_scale = 1
+        self.noobject_scale = 1
+        self.object_scale = 5
+        self.class_scale = 1
+        self.thresh = 0.6
+        self.stride = stride
+        self.seen = 0
+        self.scale_x_y = 1
+
+        self.model_out = model_out
+
+    def forward(self, output, target=None):
+        if self.training:
+            return output
+        masked_anchors = []
+        for m in self.anchor_mask:
+            masked_anchors += self.anchors[
+                m * self.anchor_step : (m + 1) * self.anchor_step
+            ]
+        masked_anchors = [anchor / self.stride for anchor in masked_anchors]
+
+        return yolo_forward_dynamic(
+            output,
+            self.thresh,
+            self.num_classes,
+            masked_anchors,
+            len(self.anchor_mask),
+            scale_x_y=self.scale_x_y,
+        )
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Methods

+
+
+def forward(self, output, target=None) ‑>Β Callable[...,Β Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, output, target=None):
+    if self.training:
+        return output
+    masked_anchors = []
+    for m in self.anchor_mask:
+        masked_anchors += self.anchors[
+            m * self.anchor_step : (m + 1) * self.anchor_step
+        ]
+    masked_anchors = [anchor / self.stride for anchor in masked_anchors]
+
+    return yolo_forward_dynamic(
+        output,
+        self.thresh,
+        self.num_classes,
+        masked_anchors,
+        len(self.anchor_mask),
+        scale_x_y=self.scale_x_y,
+    )
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/files/index.html b/docs/api/carvekit/ml/files/index.html new file mode 100644 index 0000000..6bf5e0b --- /dev/null +++ b/docs/api/carvekit/ml/files/index.html @@ -0,0 +1,77 @@ + + + + + + +carvekit.ml.files API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.files

+
+
+
+ +Expand source code + +
from pathlib import Path
+
+carvekit_dir = Path.home().joinpath(".cache/carvekit")
+
+carvekit_dir.mkdir(parents=True, exist_ok=True)
+
+checkpoints_dir = carvekit_dir.joinpath("checkpoints")
+
+
+
+

Sub-modules

+
+
carvekit.ml.files.models_loc
+
+ +
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/files/models_loc.html b/docs/api/carvekit/ml/files/models_loc.html new file mode 100644 index 0000000..24e43a5 --- /dev/null +++ b/docs/api/carvekit/ml/files/models_loc.html @@ -0,0 +1,417 @@ + + + + + + +carvekit.ml.files.models_loc API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.files.models_loc

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import pathlib
+from carvekit.ml.files import checkpoints_dir
+from carvekit.utils.download_models import downloader
+
+
+def u2net_full_pretrained() -> pathlib.Path:
+    """Returns u2net pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("u2net.pth")
+
+
+def basnet_pretrained() -> pathlib.Path:
+    """Returns basnet pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("basnet.pth")
+
+
+def deeplab_pretrained() -> pathlib.Path:
+    """Returns deeplab pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("deeplab.pth")
+
+
+def fba_pretrained() -> pathlib.Path:
+    """Returns FBA matting pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("fba_matting.pth")
+
+
+def tracer_b7_pretrained() -> pathlib.Path:
+    """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("tracer_b7.pth")
+
+
+def scene_classifier_pretrained() -> pathlib.Path:
+    """Returns scene classifier pretrained model location
+    This model is used to classify scenes into 3 categories: hard, soft, digital
+
+    hard - scenes with hard edges, such as objects, buildings, etc.
+    soft - scenes with soft edges, such as portraits, hairs, animal, etc.
+    digital - digital scenes, such as screenshots, graphics, etc.
+
+    more info: https://huggingface.co/Carve/scene_classifier
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("scene_classifier.pth")
+
+
+def yolov4_coco_pretrained() -> pathlib.Path:
+    """Returns yolov4 classifier pretrained model location
+    This model is used to classify objects in images.
+
+    Training dataset: COCO 2017
+    Training classes: 80
+
+    It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch)
+    We have only added coco classnames to the model.
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("yolov4_coco_with_classes.pth")
+
+
+def cascadepsp_pretrained() -> pathlib.Path:
+    """Returns cascade psp pretrained model location
+    This model is used to refine segmentation masks.
+
+    Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000
+    more info: https://huggingface.co/Carve/cascadepsp
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("cascadepsp.pth")
+
+
+def download_all():
+    u2net_full_pretrained()
+    fba_pretrained()
+    deeplab_pretrained()
+    basnet_pretrained()
+    tracer_b7_pretrained()
+    scene_classifier_pretrained()
+    yolov4_coco_pretrained()
+    cascadepsp_pretrained()
+
+
+
+
+
+
+
+

Functions

+
+
+def basnet_pretrained() ‑> pathlib.Path +
+
+

Returns basnet pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def basnet_pretrained() -> pathlib.Path:
+    """Returns basnet pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("basnet.pth")
+
+
+
+def cascadepsp_pretrained() ‑> pathlib.Path +
+
+

Returns cascade psp pretrained model location +This model is used to refine segmentation masks.

+

Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000 +more info: https://huggingface.co/Carve/cascadepsp

+

Returns

+

pathlib.Path to model location

+
+ +Expand source code + +
def cascadepsp_pretrained() -> pathlib.Path:
+    """Returns cascade psp pretrained model location
+    This model is used to refine segmentation masks.
+
+    Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000
+    more info: https://huggingface.co/Carve/cascadepsp
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("cascadepsp.pth")
+
+
+
+def deeplab_pretrained() ‑> pathlib.Path +
+
+

Returns deeplab pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def deeplab_pretrained() -> pathlib.Path:
+    """Returns deeplab pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("deeplab.pth")
+
+
+
+def download_all() +
+
+
+
+ +Expand source code + +
def download_all():
+    u2net_full_pretrained()
+    fba_pretrained()
+    deeplab_pretrained()
+    basnet_pretrained()
+    tracer_b7_pretrained()
+    scene_classifier_pretrained()
+    yolov4_coco_pretrained()
+    cascadepsp_pretrained()
+
+
+
+def fba_pretrained() ‑> pathlib.Path +
+
+

Returns FBA matting pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def fba_pretrained() -> pathlib.Path:
+    """Returns FBA matting pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("fba_matting.pth")
+
+
+
+def scene_classifier_pretrained() ‑> pathlib.Path +
+
+

Returns scene classifier pretrained model location +This model is used to classify scenes into 3 categories: hard, soft, digital

+

hard - scenes with hard edges, such as objects, buildings, etc. +soft - scenes with soft edges, such as portraits, hairs, animal, etc. +digital - digital scenes, such as screenshots, graphics, etc.

+

more info: https://huggingface.co/Carve/scene_classifier

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def scene_classifier_pretrained() -> pathlib.Path:
+    """Returns scene classifier pretrained model location
+    This model is used to classify scenes into 3 categories: hard, soft, digital
+
+    hard - scenes with hard edges, such as objects, buildings, etc.
+    soft - scenes with soft edges, such as portraits, hairs, animal, etc.
+    digital - digital scenes, such as screenshots, graphics, etc.
+
+    more info: https://huggingface.co/Carve/scene_classifier
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("scene_classifier.pth")
+
+
+
+def tracer_b7_pretrained() ‑> pathlib.Path +
+
+

Returns TRACER with EfficientNet v1 b7 encoder pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def tracer_b7_pretrained() -> pathlib.Path:
+    """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("tracer_b7.pth")
+
+
+
+def u2net_full_pretrained() ‑> pathlib.Path +
+
+

Returns u2net pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def u2net_full_pretrained() -> pathlib.Path:
+    """Returns u2net pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("u2net.pth")
+
+
+
+def yolov4_coco_pretrained() ‑> pathlib.Path +
+
+

Returns yolov4 classifier pretrained model location +This model is used to classify objects in images.

+

Training dataset: COCO 2017 +Training classes: 80

+

It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch) +We have only added coco classnames to the model.

+

Returns

+

pathlib.Path to model location

+
+ +Expand source code + +
def yolov4_coco_pretrained() -> pathlib.Path:
+    """Returns yolov4 classifier pretrained model location
+    This model is used to classify objects in images.
+
+    Training dataset: COCO 2017
+    Training classes: 80
+
+    It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch)
+    We have only added coco classnames to the model.
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("yolov4_coco_with_classes.pth")
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/index.html b/docs/api/carvekit/ml/index.html new file mode 100644 index 0000000..b09ab9d --- /dev/null +++ b/docs/api/carvekit/ml/index.html @@ -0,0 +1,84 @@ + + + + + + +carvekit.ml API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml

+
+
+
+ +Expand source code + +
from carvekit.utils.models_utils import fix_seed, suppress_warnings
+
+fix_seed()
+suppress_warnings()
+
+
+
+

Sub-modules

+
+
carvekit.ml.arch
+
+
+
+
carvekit.ml.files
+
+
+
+
carvekit.ml.wrap
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/basnet.html b/docs/api/carvekit/ml/wrap/basnet.html new file mode 100644 index 0000000..fa3f82e --- /dev/null +++ b/docs/api/carvekit/ml/wrap/basnet.html @@ -0,0 +1,474 @@ + + + + + + +carvekit.ml.wrap.basnet API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.basnet

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import pathlib
+from typing import Union, List
+
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+
+from carvekit.ml.arch.basnet.basnet import BASNet
+from carvekit.ml.files.models_loc import basnet_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["BASNET"]
+
+
+class BASNET(BASNet):
+    """BASNet model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the BASNET model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=320): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision **not supported at this moment**
+        """
+        super(BASNET, self).__init__(n_channels=3, n_classes=1)
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(basnet_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=np.float64)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(
+                    batches
+                )
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, d8, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BASNET +(device='cpu', input_image_size: Union[List[int], int] = 320, batch_size: int = 10, load_pretrained: bool = True, fp16: bool = False) +
+
+

BASNet model interface

+

Initialize the BASNET model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_image_size : Union[List[int], int], default=320
+
input image size
+
batch_size : int, default=10
+
the number of images that the neural network processes in one run
+
load_pretrained : bool, default=True
+
loading pretrained model
+
fp16 : bool, default=False
+
use fp16 precision not supported at this moment
+
+
+ +Expand source code + +
class BASNET(BASNet):
+    """BASNet model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the BASNET model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=320): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision **not supported at this moment**
+        """
+        super(BASNET, self).__init__(n_channels=3, n_classes=1)
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(basnet_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=np.float64)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(
+                    batches
+                )
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, d8, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+

Ancestors

+
    +
  • BASNet
  • +
  • torch.nn.modules.module.Module
  • +
+

Static methods

+
+
+def data_postprocessing(data: torch.Tensor, original_image: PIL.Image.Image) ‑> PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask as PIL Image instance
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+    """
+    data = data.unsqueeze(0)
+    mask = data[:, 0, :, :]
+    ma = torch.max(mask)  # Normalizes prediction
+    mi = torch.min(mask)
+    predict = ((mask - mi) / (ma - mi)).squeeze()
+    predict_np = predict.cpu().data.numpy() * 255
+    mask = Image.fromarray(predict_np).convert("L")
+    mask = mask.resize(original_image.size, resample=3)
+    return mask
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data: PIL.Image.Image) ‑> torch.Tensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.Tensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.Tensor: input for neural network
+
+    """
+    resized = data.resize(self.input_image_size)
+    # noinspection PyTypeChecker
+    resized_arr = np.array(resized, dtype=np.float64)
+    temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+    if np.max(resized_arr) != 0:
+        resized_arr /= np.max(resized_arr)
+    temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+    temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+    temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+    temp_image = temp_image.transpose((2, 0, 1))
+    temp_image = np.expand_dims(temp_image, 0)
+    return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/cascadepsp.html b/docs/api/carvekit/ml/wrap/cascadepsp.html new file mode 100644 index 0000000..231db25 --- /dev/null +++ b/docs/api/carvekit/ml/wrap/cascadepsp.html @@ -0,0 +1,907 @@ + + + + + + +carvekit.ml.wrap.cascadepsp API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.cascadepsp

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+import warnings
+
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+from typing import Union, List
+
+from carvekit.ml.arch.cascadepsp.pspnet import RefinementModule
+from carvekit.ml.arch.cascadepsp.utils import (
+    process_im_single_pass,
+    process_high_res_im,
+)
+from carvekit.ml.files.models_loc import cascadepsp_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["CascadePSP"]
+
+
+class CascadePSP(RefinementModule):
+    """
+    CascadePSP to refine the mask from segmentation network
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        input_tensor_size: int = 900,
+        batch_size: int = 1,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        mask_binary_threshold=127,
+        global_step_only=False,
+        processing_accelerate_image_size=2048,
+    ):
+        """
+        Initialize the CascadePSP model
+
+        Args:
+            device: processing device
+            input_tensor_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use half precision
+            global_step_only: if True, only global step will be used for prediction. See paper for details.
+            mask_binary_threshold: threshold for binary mask, default 127, set to 0 for no threshold
+            processing_accelerate_image_size: thumbnail size for image processing acceleration. Set to 0 to disable
+
+        """
+        super().__init__()
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        self.mask_binary_threshold = mask_binary_threshold
+        self.global_step_only = global_step_only
+        self.processing_accelerate_image_size = processing_accelerate_image_size
+        self.input_tensor_size = input_tensor_size
+
+        self.to(device)
+        if batch_size > 1:
+            warnings.warn(
+                "Batch size > 1 is experimental feature for CascadePSP."
+                " Please, don't use it if you have GPU with small memory!"
+            )
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(cascadepsp_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+        self._image_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+        self._seg_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5], std=[0.5]),
+            ]
+        )
+
+    def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+        preprocessed_data = data.copy()
+        if self.batch_size == 1 and self.processing_accelerate_image_size > 0:
+            # We have only one image, so we can downscale it to a thumbnail
+            # to accelerate processing of high-resolution images
+            preprocessed_data.thumbnail(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif self.batch_size == 1:
+            pass  # No need to do anything
+        elif self.batch_size > 1 and self.global_step_only is True:
+            # If we have more than one image and we use only the global step,
+            # there is no reason to use image processing acceleration,
+            # because only the global step is used for prediction and the image is resized to input_tensor_size anyway
+            preprocessed_data = preprocessed_data.resize(
+                (self.input_tensor_size, self.input_tensor_size)
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and self.processing_accelerate_image_size > 0
+        ):
+            # If we have more than one image and we use the local step,
+            # we can use image processing acceleration to speed up high-resolution image processing,
+            # but we need to resize the image to processing_accelerate_image_size to stack it with other images
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and not (self.processing_accelerate_image_size > 0)
+        ):
+            raise ValueError(
+                "If you use local step with batch_size > 2, "
+                "you need to set processing_accelerate_image_size > 0,"
+                "since we cannot stack images with different sizes to one batch"
+            )
+        else:  # some extra cases
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+
+        if data.mode == "RGB":
+            preprocessed_data = self._image_transform(
+                np.array(preprocessed_data)
+            ).unsqueeze(0)
+        elif data.mode == "L":
+            preprocessed_data = np.array(preprocessed_data)
+            if 0 < self.mask_binary_threshold <= 255:
+                preprocessed_data = (
+                    preprocessed_data > self.mask_binary_threshold
+                ).astype(np.uint8) * 255
+            elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0:
+                warnings.warn(
+                    "mask_binary_threshold should be in range [0, 255], "
+                    "but got {}. Disabling mask_binary_threshold!".format(
+                        self.mask_binary_threshold
+                    )
+                )
+
+            preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze(
+                0
+            )  # [H,W,1]
+
+        return preprocessed_data
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, mask: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            mask: input mask
+
+        Returns:
+            Segmentation mask as PIL Image instance
+
+        """
+        refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8")
+        return Image.fromarray(refined_mask).convert("L").resize(mask.size)
+
+    def safe_forward(self, im, seg, inter_s8=None, inter_s4=None):
+        """
+        Slightly pads the input image such that its length is a multiple of 8
+        """
+        b, _, ph, pw = seg.shape
+        if (ph % 8 != 0) or (pw % 8 != 0):
+            newH = (ph // 8 + 1) * 8
+            newW = (pw // 8 + 1) * 8
+            p_im = torch.zeros(b, 3, newH, newW, device=im.device)
+            p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+
+            p_im[:, :, 0:ph, 0:pw] = im
+            p_seg[:, :, 0:ph, 0:pw] = seg
+            im = p_im
+            seg = p_seg
+
+            if inter_s8 is not None:
+                p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8
+                inter_s8 = p_inter_s8
+            if inter_s4 is not None:
+                p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4
+                inter_s4 = p_inter_s4
+
+        images = super().__call__(im, seg, inter_s8, inter_s4)
+        return_im = {}
+
+        for key in ["pred_224", "pred_28_3", "pred_56_2"]:
+            return_im[key] = images[key][:, :, 0:ph, 0:pw]
+        del images
+
+        return return_im
+
+    def __call__(
+        self,
+        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
+        masks: List[Union[str, pathlib.Path, PIL.Image.Image]],
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+            masks: Segmentation masks to refine
+
+        Returns:
+            segmentation masks as for input images, as PIL.Image.Image instances
+
+        """
+
+        if len(images) != len(masks):
+            raise ValueError(
+                "Len of specified arrays of images and trimaps should be equal!"
+            )
+
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for idx_batch in batch_generator(range(len(images)), self.batch_size):
+                inpt_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(images[x])), idx_batch
+                )
+
+                inpt_masks = thread_pool_processing(
+                    lambda x: convert_image(load_image(masks[x]), mode="L"), idx_batch
+                )
+
+                inpt_img_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_images
+                )
+                inpt_masks_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_masks
+                )
+                if self.batch_size > 1:  # We need to stack images, if batch_size > 1
+                    inpt_img_batches = torch.vstack(inpt_img_batches)
+                    inpt_masks_batches = torch.vstack(inpt_masks_batches)
+                else:
+                    inpt_img_batches = inpt_img_batches[
+                        0
+                    ]  # Get only one image from list
+                    inpt_masks_batches = inpt_masks_batches[0]
+
+                with torch.no_grad():
+                    inpt_img_batches = inpt_img_batches.to(self.device)
+                    inpt_masks_batches = inpt_masks_batches.to(self.device)
+                    if self.global_step_only:
+                        refined_batches = process_im_single_pass(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    else:
+                        refined_batches = process_high_res_im(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    refined_masks = refined_batches.cpu()
+                    del (inpt_img_batches, inpt_masks_batches, refined_batches)
+                collect_masks += thread_pool_processing(
+                    lambda x: self.data_postprocessing(refined_masks[x], inpt_masks[x]),
+                    range(len(inpt_masks)),
+                )
+            return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CascadePSP +(device='cpu', input_tensor_size:Β intΒ =Β 900, batch_size:Β intΒ =Β 1, load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False, mask_binary_threshold=127, global_step_only=False, processing_accelerate_image_size=2048) +
+
+

CascadePSP to refine the mask from segmentation network

+

Initialize the CascadePSP model

+

Args

+
+
device
+
processing device
+
input_tensor_size
+
input image size
+
batch_size
+
the number of images that the neural network processes in one run
+
load_pretrained
+
loading pretrained model
+
fp16
+
use half precision
+
global_step_only
+
if True, only global step will be used for prediction. See paper for details.
+
mask_binary_threshold
+
threshold for binary mask, default 127, set to 0 for no threshold
+
processing_accelerate_image_size
+
thumbnail size for image processing acceleration. Set to 0 to disable
+
+
+ +Expand source code + +
class CascadePSP(RefinementModule):
+    """
+    CascadePSP to refine the mask from segmentation network
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        input_tensor_size: int = 900,
+        batch_size: int = 1,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        mask_binary_threshold=127,
+        global_step_only=False,
+        processing_accelerate_image_size=2048,
+    ):
+        """
+        Initialize the CascadePSP model
+
+        Args:
+            device: processing device
+            input_tensor_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use half precision
+            global_step_only: if True, only global step will be used for prediction. See paper for details.
+            mask_binary_threshold: threshold for binary mask, default 70, set to 0 for no threshold
+            processing_accelerate_image_size: thumbnail size for image processing acceleration. Set to 0 to disable
+
+        """
+        super().__init__()
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        self.mask_binary_threshold = mask_binary_threshold
+        self.global_step_only = global_step_only
+        self.processing_accelerate_image_size = processing_accelerate_image_size
+        self.input_tensor_size = input_tensor_size
+
+        self.to(device)
+        if batch_size > 1:
+            warnings.warn(
+                "Batch size > 1 is experimental feature for CascadePSP."
+                " Please, don't use it if you have GPU with small memory!"
+            )
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(cascadepsp_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+        self._image_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+        self._seg_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5], std=[0.5]),
+            ]
+        )
+
+    def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+        preprocessed_data = data.copy()
+        if self.batch_size == 1 and self.processing_accelerate_image_size > 0:
+            # Okay, we have only one image, so
+            # we can use image processing acceleration for accelerate high resolution image processing
+            preprocessed_data.thumbnail(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif self.batch_size == 1:
+            pass  # No need to do anything
+        elif self.batch_size > 1 and self.global_step_only is True:
+            # If we have more than one image and we use only global step,
+            # there aren't any reason to use image processing acceleration,
+            # because we will use only global step for prediction and anyway it will be resized to input_tensor_size
+            preprocessed_data = preprocessed_data.resize(
+                (self.input_tensor_size, self.input_tensor_size)
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and self.processing_accelerate_image_size > 0
+        ):
+            # If we have more than one image and we use local step,
+            # we can use image processing acceleration for accelerate high resolution image processing
+            # but we need to resize image to processing_accelerate_image_size to stack it with other images
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and not (self.processing_accelerate_image_size > 0)
+        ):
+            raise ValueError(
+                "If you use local step with batch_size > 2, "
+                "you need to set processing_accelerate_image_size > 0,"
+                "since we cannot stack images with different sizes to one batch"
+            )
+        else:  # some extra cases
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+
+        if data.mode == "RGB":
+            preprocessed_data = self._image_transform(
+                np.array(preprocessed_data)
+            ).unsqueeze(0)
+        elif data.mode == "L":
+            preprocessed_data = np.array(preprocessed_data)
+            if 0 < self.mask_binary_threshold <= 255:
+                preprocessed_data = (
+                    preprocessed_data > self.mask_binary_threshold
+                ).astype(np.uint8) * 255
+            elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0:
+                warnings.warn(
+                    "mask_binary_threshold should be in range [0, 255], "
+                    "but got {}. Disabling mask_binary_threshold!".format(
+                        self.mask_binary_threshold
+                    )
+                )
+
+            preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze(
+                0
+            )  # [H,W,1]
+
+        return preprocessed_data
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, mask: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            mask: input mask
+
+        Returns:
+            Segmentation mask as PIL Image instance
+
+        """
+        refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8")
+        return Image.fromarray(refined_mask).convert("L").resize(mask.size)
+
+    def safe_forward(self, im, seg, inter_s8=None, inter_s4=None):
+        """
+        Slightly pads the input image such that its length is a multiple of 8
+        """
+        b, _, ph, pw = seg.shape
+        if (ph % 8 != 0) or (pw % 8 != 0):
+            newH = (ph // 8 + 1) * 8
+            newW = (pw // 8 + 1) * 8
+            p_im = torch.zeros(b, 3, newH, newW, device=im.device)
+            p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+
+            p_im[:, :, 0:ph, 0:pw] = im
+            p_seg[:, :, 0:ph, 0:pw] = seg
+            im = p_im
+            seg = p_seg
+
+            if inter_s8 is not None:
+                p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8
+                inter_s8 = p_inter_s8
+            if inter_s4 is not None:
+                p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4
+                inter_s4 = p_inter_s4
+
+        images = super().__call__(im, seg, inter_s8, inter_s4)
+        return_im = {}
+
+        for key in ["pred_224", "pred_28_3", "pred_56_2"]:
+            return_im[key] = images[key][:, :, 0:ph, 0:pw]
+        del images
+
+        return return_im
+
+    def __call__(
+        self,
+        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
+        masks: List[Union[str, pathlib.Path, PIL.Image.Image]],
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+            masks: Segmentation masks to refine
+
+        Returns:
+            segmentation masks as for input images, as PIL.Image.Image instances
+
+        """
+
+        if len(images) != len(masks):
+            raise ValueError(
+                "Len of specified arrays of images and trimaps should be equal!"
+            )
+
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for idx_batch in batch_generator(range(len(images)), self.batch_size):
+                inpt_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(images[x])), idx_batch
+                )
+
+                inpt_masks = thread_pool_processing(
+                    lambda x: convert_image(load_image(masks[x]), mode="L"), idx_batch
+                )
+
+                inpt_img_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_images
+                )
+                inpt_masks_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_masks
+                )
+                if self.batch_size > 1:  # We need to stack images, if batch_size > 1
+                    inpt_img_batches = torch.vstack(inpt_img_batches)
+                    inpt_masks_batches = torch.vstack(inpt_masks_batches)
+                else:
+                    inpt_img_batches = inpt_img_batches[
+                        0
+                    ]  # Get only one image from list
+                    inpt_masks_batches = inpt_masks_batches[0]
+
+                with torch.no_grad():
+                    inpt_img_batches = inpt_img_batches.to(self.device)
+                    inpt_masks_batches = inpt_masks_batches.to(self.device)
+                    if self.global_step_only:
+                        refined_batches = process_im_single_pass(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    else:
+                        refined_batches = process_high_res_im(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    refined_masks = refined_batches.cpu()
+                    del (inpt_img_batches, inpt_masks_batches, refined_batches)
+                collect_masks += thread_pool_processing(
+                    lambda x: self.data_postprocessing(refined_masks[x], inpt_masks[x]),
+                    range(len(inpt_masks)),
+                )
+            return collect_masks
+
+

Ancestors

+ +

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, mask:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data
+
output data from neural network
+
mask
+
input mask
+
+

Returns

+

Segmentation mask as PIL Image instance

+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, mask: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data: output data from neural network
+        mask: input mask
+
+    Returns:
+        Segmentation mask as PIL Image instance
+
+    """
+    refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8")
+    return Image.fromarray(refined_mask).convert("L").resize(mask.size)
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data
+
input image
+
+

Returns

+

input for neural network

+
+ +Expand source code + +
def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data: input image
+
+    Returns:
+        input for neural network
+
+    """
+    preprocessed_data = data.copy()
+    if self.batch_size == 1 and self.processing_accelerate_image_size > 0:
+        # Okay, we have only one image, so
+        # we can use image processing acceleration for accelerate high resolution image processing
+        preprocessed_data.thumbnail(
+            (
+                self.processing_accelerate_image_size,
+                self.processing_accelerate_image_size,
+            )
+        )
+    elif self.batch_size == 1:
+        pass  # No need to do anything
+    elif self.batch_size > 1 and self.global_step_only is True:
+        # If we have more than one image and we use only global step,
+        # there aren't any reason to use image processing acceleration,
+        # because we will use only global step for prediction and anyway it will be resized to input_tensor_size
+        preprocessed_data = preprocessed_data.resize(
+            (self.input_tensor_size, self.input_tensor_size)
+        )
+    elif (
+        self.batch_size > 1
+        and self.global_step_only is False
+        and self.processing_accelerate_image_size > 0
+    ):
+        # If we have more than one image and we use local step,
+        # we can use image processing acceleration for accelerate high resolution image processing
+        # but we need to resize image to processing_accelerate_image_size to stack it with other images
+        preprocessed_data = preprocessed_data.resize(
+            (
+                self.processing_accelerate_image_size,
+                self.processing_accelerate_image_size,
+            )
+        )
+    elif (
+        self.batch_size > 1
+        and self.global_step_only is False
+        and not (self.processing_accelerate_image_size > 0)
+    ):
+        raise ValueError(
+            "If you use local step with batch_size > 2, "
+            "you need to set processing_accelerate_image_size > 0,"
+            "since we cannot stack images with different sizes to one batch"
+        )
+    else:  # some extra cases
+        preprocessed_data = preprocessed_data.resize(
+            (
+                self.processing_accelerate_image_size,
+                self.processing_accelerate_image_size,
+            )
+        )
+
+    if data.mode == "RGB":
+        preprocessed_data = self._image_transform(
+            np.array(preprocessed_data)
+        ).unsqueeze(0)
+    elif data.mode == "L":
+        preprocessed_data = np.array(preprocessed_data)
+        if 0 < self.mask_binary_threshold <= 255:
+            preprocessed_data = (
+                preprocessed_data > self.mask_binary_threshold
+            ).astype(np.uint8) * 255
+        elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0:
+            warnings.warn(
+                "mask_binary_threshold should be in range [0, 255], "
+                "but got {}. Disabling mask_binary_threshold!".format(
+                    self.mask_binary_threshold
+                )
+            )
+
+        preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze(
+            0
+        )  # [H,W,1]
+
+    return preprocessed_data
+
+
+
+def safe_forward(self, im, seg, inter_s8=None, inter_s4=None) +
+
+

Slightly pads the input image such that its length is a multiple of 8

+
+ +Expand source code + +
def safe_forward(self, im, seg, inter_s8=None, inter_s4=None):
+    """
+    Slightly pads the input image such that its length is a multiple of 8
+    """
+    b, _, ph, pw = seg.shape
+    if (ph % 8 != 0) or (pw % 8 != 0):
+        newH = (ph // 8 + 1) * 8
+        newW = (pw // 8 + 1) * 8
+        p_im = torch.zeros(b, 3, newH, newW, device=im.device)
+        p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+
+        p_im[:, :, 0:ph, 0:pw] = im
+        p_seg[:, :, 0:ph, 0:pw] = seg
+        im = p_im
+        seg = p_seg
+
+        if inter_s8 is not None:
+            p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+            p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8
+            inter_s8 = p_inter_s8
+        if inter_s4 is not None:
+            p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+            p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4
+            inter_s4 = p_inter_s4
+
+    images = super().__call__(im, seg, inter_s8, inter_s4)
+    return_im = {}
+
+    for key in ["pred_224", "pred_28_3", "pred_56_2"]:
+        return_im[key] = images[key][:, :, 0:ph, 0:pw]
+    del images
+
+    return return_im
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/deeplab_v3.html b/docs/api/carvekit/ml/wrap/deeplab_v3.html new file mode 100644 index 0000000..57e1eb1 --- /dev/null +++ b/docs/api/carvekit/ml/wrap/deeplab_v3.html @@ -0,0 +1,490 @@ + + + + + + +carvekit.ml.wrap.deeplab_v3 API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.deeplab_v3

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import pathlib
+from typing import List, Union
+
+import PIL.Image
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.models.segmentation import deeplabv3_resnet101
+from carvekit.ml.files.models_loc import deeplab_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["DeepLabV3"]
+
+
+class DeepLabV3:
+    def __init__(
+        self,
+        device="cpu",
+        batch_size: int = 10,
+        input_image_size: Union[List[int], int] = 1024,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the `DeepLabV3` model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use half precision
+
+        """
+        self.device = device
+        self.batch_size = batch_size
+        self.network = deeplabv3_resnet101(
+            pretrained=False, pretrained_backbone=False, aux_loss=True
+        )
+        self.network.to(self.device)
+        if load_pretrained:
+            self.network.load_state_dict(
+                torch.load(deeplab_pretrained(), map_location=self.device)
+            )
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.network.eval()
+        self.fp16 = fp16
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+    def to(self, device: str):
+        """
+        Moves neural network to specified processing device
+
+        Args:
+            device (Literal[cpu, cuda]): the desired device.
+
+        """
+        self.network.to(device)
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        copy = data.copy()
+        copy.thumbnail(self.input_image_size, resample=3)
+        return self.transform(copy)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        return (
+            Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size)
+        )
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images though neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks as for input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.network, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = thread_pool_processing(
+                    self.data_preprocessing, converted_images
+                )
+                with torch.no_grad():
+                    masks = [
+                        self.network(i.to(self.device).unsqueeze(0))["out"][0]
+                        .argmax(0)
+                        .byte()
+                        .cpu()
+                        for i in batches
+                    ]
+                    del batches
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks[x], converted_images[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class DeepLabV3 +(device='cpu', batch_size:Β intΒ =Β 10, input_image_size:Β Union[List[int],Β int]Β =Β 1024, load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False) +
+
+

Initialize the DeepLabV3 model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_image_size (): input image size
+
batch_size : int, default=10
+
the number of images that the neural network processes in one run
+
load_pretrained : bool, default=True
+
loading pretrained model
+
fp16 : bool, default=False
+
use half precision
+
+
+ +Expand source code + +
class DeepLabV3:
+    """
+    Wrapper around the `deeplabv3_resnet101` segmentation network that turns
+    input images into segmentation masks.
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        batch_size: int = 10,
+        input_image_size: Union[List[int], int] = 1024,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the `DeepLabV3` model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=1024): input image size.
+                An int is used for both dimensions; a list/tuple is truncated to its first two items.
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use half precision
+
+        """
+        self.device = device
+        self.batch_size = batch_size
+        self.network = deeplabv3_resnet101(
+            pretrained=False, pretrained_backbone=False, aux_loss=True
+        )
+        self.network.to(self.device)
+        if load_pretrained:
+            self.network.load_state_dict(
+                torch.load(deeplab_pretrained(), map_location=self.device)
+            )
+        # Tuples are accepted as well as lists; previously a tuple fell into
+        # the scalar branch and produced a nested (tuple, tuple) size.
+        if isinstance(input_image_size, (list, tuple)):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.network.eval()
+        self.fp16 = fp16
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+    def to(self, device: str):
+        """
+        Moves neural network to specified processing device
+
+        Args:
+            device (Literal[cpu, cuda]): the desired device.
+
+        """
+        self.network.to(device)
+        # Keep the recorded device in sync: `__call__` moves every input to
+        # `self.device` and picks the autocast context from it.
+        self.device = device
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        # `thumbnail` works in place, so operate on a copy of the input.
+        copy = data.copy()
+        copy.thumbnail(self.input_image_size, resample=3)
+        return self.transform(copy)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+                (per-pixel class indices, as produced in `__call__`)
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        # Multiplying the uint8 class-index map by 255 wraps modulo 256
+        # (e.g. index 15 -> 241), so reduce it to a binary mask first:
+        # index 0 stays 0, every other index becomes exactly 255.
+        foreground = (data != 0).byte() * 255
+        return (
+            Image.fromarray(foreground.numpy())
+            .convert("L")
+            .resize(original_image.size)
+        )
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.network, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = thread_pool_processing(
+                    self.data_preprocessing, converted_images
+                )
+                with torch.no_grad():
+                    # Images are fed one at a time; each result is reduced to
+                    # a per-pixel class index map on the CPU.
+                    masks = [
+                        self.network(i.to(self.device).unsqueeze(0))["out"][0]
+                        .argmax(0)
+                        .byte()
+                        .cpu()
+                        for i in batches
+                    ]
+                    del batches
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks[x], converted_images[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+        return collect_masks
+
+

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, original_image:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask as PIL Image instance
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+            (per-pixel class indices)
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+    """
+    # Multiplying a uint8 class-index map by 255 wraps modulo 256
+    # (e.g. index 15 -> 241), so reduce it to a binary mask first:
+    # index 0 stays 0, every other index becomes exactly 255.
+    foreground = (data != 0).byte() * 255
+    return (
+        Image.fromarray(foreground.numpy())
+        .convert("L")
+        .resize(original_image.size)
+    )
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.Tensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.Tensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+    """
+    Prepare a single image for the neural network.
+
+    Args:
+        data (PIL.Image.Image): input image (left untouched)
+
+    Returns:
+        torch.Tensor: result of applying `self.transform` to the downscaled copy
+
+    """
+    # `thumbnail` resizes in place, so shrink a private copy instead of `data`.
+    thumb = data.copy()
+    target_size = self.input_image_size
+    thumb.thumbnail(target_size, resample=3)
+    tensor = self.transform(thumb)
+    return tensor
+
+
+
+def to(self, device:Β str) +
+
+

Moves neural network to specified processing device

+

Args

+
+
device : Literal[cpu, cuda]
+
the desired device.
+
+
+ +Expand source code + +
def to(self, device: str):
+    """
+    Moves neural network to specified processing device
+
+    Args:
+        device (Literal[cpu, cuda]): the desired device.
+
+    """
+    self.network.to(device)
+    # Keep the recorded device in sync with the network so later code that
+    # reads `self.device` targets the device the weights now live on.
+    self.device = device
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/fba_matting.html b/docs/api/carvekit/ml/wrap/fba_matting.html new file mode 100644 index 0000000..9aee452 --- /dev/null +++ b/docs/api/carvekit/ml/wrap/fba_matting.html @@ -0,0 +1,674 @@ + + + + + + +carvekit.ml.wrap.fba_matting API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.fba_matting

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+from typing import Union, List, Tuple
+
+import PIL
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from carvekit.ml.arch.fba_matting.models import FBA
+from carvekit.ml.arch.fba_matting.transforms import (
+    trimap_transform,
+    groupnorm_normalise_image,
+)
+from carvekit.ml.files.models_loc import fba_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["FBAMatting"]
+
+
+class FBAMatting(FBA):
+    """
+    FBA Matting Neural Network to improve edges on image.
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        input_tensor_size: Union[List[int], int] = 2048,
+        batch_size: int = 2,
+        encoder="resnet50_GN_WS",
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the FBAMatting model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_tensor_size (Union[List[int], int], default=2048): input image size.
+                An int is used for both dimensions; a list/tuple is truncated to its first two items.
+            batch_size (int, default=2): the number of images that the neural network processes in one run
+            encoder (str, default=resnet50_GN_WS): neural network encoder head
+            .. TODO::
+                Add more encoders to documentation as Literal typehint.
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use half precision
+
+        """
+        super(FBAMatting, self).__init__(encoder=encoder)
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        # Tuples are accepted as well as lists; previously a tuple fell into
+        # the scalar branch and produced a nested (tuple, tuple) size.
+        if isinstance(input_tensor_size, (list, tuple)):
+            self.input_image_size = input_tensor_size[:2]
+        else:
+            self.input_image_size = (input_tensor_size, input_tensor_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
+        self.eval()
+
+    def data_preprocessing(
+        self, data: Union[PIL.Image.Image, np.ndarray]
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (Union[PIL.Image.Image, np.ndarray]): input image.
+                NOTE(review): the body calls `.thumbnail`/`.resize` and reads `.mode`,
+                so in practice a PIL image in mode "RGB" (image) or "L" (trimap) is required.
+
+        Returns:
+            Tuple[torch.FloatTensor, torch.FloatTensor]: input for neural network
+                (the raw tensor and its normalised/transformed counterpart)
+
+        Raises:
+            ValueError: if the image mode is neither "RGB" nor "L"
+
+        """
+        resized = data.copy()
+        # With a single image per batch the aspect ratio can be kept;
+        # otherwise every image is forced to the same size so the
+        # per-image tensors can be stacked in `__call__`.
+        if self.batch_size == 1:
+            resized.thumbnail(self.input_image_size, resample=3)
+        else:
+            resized = resized.resize(self.input_image_size, resample=3)
+        # noinspection PyTypeChecker
+        image = np.array(resized, dtype=np.float64)
+        image = image / 255.0  # Normalize image to [0, 1] values range
+        if resized.mode == "RGB":
+            image = image[:, :, ::-1]
+        elif resized.mode == "L":
+            image2 = np.copy(image)
+            h, w = image2.shape
+            image = np.zeros((h, w, 2))  # Transform trimap to binary data format
+            image[image2 == 1, 1] = 1
+            image[image2 == 0, 0] = 1
+        else:
+            raise ValueError("Incorrect color mode for image")
+        h, w = image.shape[:2]  # Scale input mlt to 8
+        h1 = int(np.ceil(1.0 * h / 8) * 8)
+        w1 = int(np.ceil(1.0 * w / 8) * 8)
+        x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4)
+        image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
+        if resized.mode == "RGB":
+            return image_tensor, groupnorm_normalise_image(
+                image_tensor.clone(), format="nchw"
+            )
+        else:
+            return (
+                image_tensor,
+                torch.from_numpy(trimap_transform(x_scale))
+                .permute(2, 0, 1)[None, :, :, :]
+                .float(),
+            )
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, trimap: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            trimap (PIL.Image.Image): Map with the area we need to refine
+
+        Returns:
+            PIL.Image.Image: Segmentation mask
+
+        Raises:
+            ValueError: if the trimap is not in mode "L"
+
+        """
+        if trimap.mode != "L":
+            raise ValueError("Incorrect color mode for trimap")
+        pred = data.numpy().transpose((1, 2, 0))
+        # `interpolation` has to be passed by keyword: the third positional
+        # parameter of `cv2.resize` is `dst`, so a positional flag is never
+        # used as the interpolation mode.
+        pred = cv2.resize(
+            pred, trimap.size, interpolation=cv2.INTER_LANCZOS4
+        )[:, :, 0]
+        # noinspection PyTypeChecker
+        # Clean mask by removing all false predictions outside trimap and already known area
+        trimap_arr = np.array(trimap.copy())
+        pred[trimap_arr[:, :] == 0] = 0
+        pred[pred < 0.3] = 0
+        # Lanczos resampling can overshoot the input range; keep alpha in [0, 1].
+        pred = np.clip(pred, 0.0, 1.0)
+        return Image.fromarray(pred * 255).convert("L")
+
+    def __call__(
+        self,
+        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
+        trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+            trimaps (List[Union[str, pathlib.Path, PIL.Image.Image]]): Maps with the areas we need to refine
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images
+
+        Raises:
+            ValueError: if `images` and `trimaps` differ in length
+
+        """
+
+        if len(images) != len(trimaps):
+            raise ValueError(
+                "Len of specified arrays of images and trimaps should be equal!"
+            )
+
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            # Batch over indices so images and trimaps stay paired.
+            for idx_batch in batch_generator(range(len(images)), self.batch_size):
+                inpt_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(images[x])), idx_batch
+                )
+
+                inpt_trimaps = thread_pool_processing(
+                    lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
+                )
+
+                inpt_img_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_images
+                )
+                inpt_trimaps_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_trimaps
+                )
+
+                # Each preprocessing result is a (raw, transformed) pair;
+                # stack the two halves into separate batch tensors.
+                inpt_img_batches_transformed = torch.vstack(
+                    [i[1] for i in inpt_img_batches]
+                )
+                inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])
+
+                inpt_trimaps_transformed = torch.vstack(
+                    [i[1] for i in inpt_trimaps_batches]
+                )
+                inpt_trimaps_batches = torch.vstack(
+                    [i[0] for i in inpt_trimaps_batches]
+                )
+
+                with torch.no_grad():
+                    inpt_img_batches = inpt_img_batches.to(self.device)
+                    inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
+                    inpt_img_batches_transformed = inpt_img_batches_transformed.to(
+                        self.device
+                    )
+                    inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)
+
+                    output = super(FBAMatting, self).__call__(
+                        inpt_img_batches,
+                        inpt_trimaps_batches,
+                        inpt_img_batches_transformed,
+                        inpt_trimaps_transformed,
+                    )
+                    output_cpu = output.cpu()
+                    # Drop the device-side tensors before the next batch.
+                    del (
+                        inpt_img_batches,
+                        inpt_trimaps_batches,
+                        inpt_img_batches_transformed,
+                        inpt_trimaps_transformed,
+                        output,
+                    )
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
+                    range(len(inpt_images)),
+                )
+                collect_masks += masks
+            return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class FBAMatting +(device='cpu', input_tensor_size:Β Union[List[int],Β int]Β =Β 2048, batch_size:Β intΒ =Β 2, encoder='resnet50_GN_WS', load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False) +
+
+

FBA Matting Neural Network to improve edges on image.

+

Initialize the FBAMatting model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_tensor_size : Union[List[int], int], default=2048
+
input image size
+
batch_size : int, default=2
+
the number of images that the neural network processes in one run
+
encoder : str, default=resnet50_GN_WS
+
neural network encoder head
+
+
+

TODO

+

Add more encoders to documentation as Literal typehint.

+
+
+
load_pretrained : bool, default=True
+
loading pretrained model
+
fp16 : bool, default=False
+
use half precision
+
+
+ +Expand source code + +
class FBAMatting(FBA):
+    """
+    FBA Matting Neural Network to improve edges on image.
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        input_tensor_size: Union[List[int], int] = 2048,
+        batch_size: int = 2,
+        encoder="resnet50_GN_WS",
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the FBAMatting model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_tensor_size (Union[List[int], int], default=2048): input image size.
+                An int is used for both dimensions; a list/tuple is truncated to its first two items.
+            batch_size (int, default=2): the number of images that the neural network processes in one run
+            encoder (str, default=resnet50_GN_WS): neural network encoder head
+            .. TODO::
+                Add more encoders to documentation as Literal typehint.
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use half precision
+
+        """
+        super(FBAMatting, self).__init__(encoder=encoder)
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        # Tuples are accepted as well as lists; previously a tuple fell into
+        # the scalar branch and produced a nested (tuple, tuple) size.
+        if isinstance(input_tensor_size, (list, tuple)):
+            self.input_image_size = input_tensor_size[:2]
+        else:
+            self.input_image_size = (input_tensor_size, input_tensor_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
+        self.eval()
+
+    def data_preprocessing(
+        self, data: Union[PIL.Image.Image, np.ndarray]
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (Union[PIL.Image.Image, np.ndarray]): input image.
+                NOTE(review): the body calls `.thumbnail`/`.resize` and reads `.mode`,
+                so in practice a PIL image in mode "RGB" (image) or "L" (trimap) is required.
+
+        Returns:
+            Tuple[torch.FloatTensor, torch.FloatTensor]: input for neural network
+                (the raw tensor and its normalised/transformed counterpart)
+
+        Raises:
+            ValueError: if the image mode is neither "RGB" nor "L"
+
+        """
+        resized = data.copy()
+        # With a single image per batch the aspect ratio can be kept;
+        # otherwise every image is forced to the same size so the
+        # per-image tensors can be stacked in `__call__`.
+        if self.batch_size == 1:
+            resized.thumbnail(self.input_image_size, resample=3)
+        else:
+            resized = resized.resize(self.input_image_size, resample=3)
+        # noinspection PyTypeChecker
+        image = np.array(resized, dtype=np.float64)
+        image = image / 255.0  # Normalize image to [0, 1] values range
+        if resized.mode == "RGB":
+            image = image[:, :, ::-1]
+        elif resized.mode == "L":
+            image2 = np.copy(image)
+            h, w = image2.shape
+            image = np.zeros((h, w, 2))  # Transform trimap to binary data format
+            image[image2 == 1, 1] = 1
+            image[image2 == 0, 0] = 1
+        else:
+            raise ValueError("Incorrect color mode for image")
+        h, w = image.shape[:2]  # Scale input mlt to 8
+        h1 = int(np.ceil(1.0 * h / 8) * 8)
+        w1 = int(np.ceil(1.0 * w / 8) * 8)
+        x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4)
+        image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
+        if resized.mode == "RGB":
+            return image_tensor, groupnorm_normalise_image(
+                image_tensor.clone(), format="nchw"
+            )
+        else:
+            return (
+                image_tensor,
+                torch.from_numpy(trimap_transform(x_scale))
+                .permute(2, 0, 1)[None, :, :, :]
+                .float(),
+            )
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, trimap: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            trimap (PIL.Image.Image): Map with the area we need to refine
+
+        Returns:
+            PIL.Image.Image: Segmentation mask
+
+        Raises:
+            ValueError: if the trimap is not in mode "L"
+
+        """
+        if trimap.mode != "L":
+            raise ValueError("Incorrect color mode for trimap")
+        pred = data.numpy().transpose((1, 2, 0))
+        # `interpolation` has to be passed by keyword: the third positional
+        # parameter of `cv2.resize` is `dst`, so a positional flag is never
+        # used as the interpolation mode.
+        pred = cv2.resize(
+            pred, trimap.size, interpolation=cv2.INTER_LANCZOS4
+        )[:, :, 0]
+        # noinspection PyTypeChecker
+        # Clean mask by removing all false predictions outside trimap and already known area
+        trimap_arr = np.array(trimap.copy())
+        pred[trimap_arr[:, :] == 0] = 0
+        pred[pred < 0.3] = 0
+        # Lanczos resampling can overshoot the input range; keep alpha in [0, 1].
+        pred = np.clip(pred, 0.0, 1.0)
+        return Image.fromarray(pred * 255).convert("L")
+
+    def __call__(
+        self,
+        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
+        trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+            trimaps (List[Union[str, pathlib.Path, PIL.Image.Image]]): Maps with the areas we need to refine
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images
+
+        Raises:
+            ValueError: if `images` and `trimaps` differ in length
+
+        """
+
+        if len(images) != len(trimaps):
+            raise ValueError(
+                "Len of specified arrays of images and trimaps should be equal!"
+            )
+
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            # Batch over indices so images and trimaps stay paired.
+            for idx_batch in batch_generator(range(len(images)), self.batch_size):
+                inpt_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(images[x])), idx_batch
+                )
+
+                inpt_trimaps = thread_pool_processing(
+                    lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
+                )
+
+                inpt_img_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_images
+                )
+                inpt_trimaps_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_trimaps
+                )
+
+                # Each preprocessing result is a (raw, transformed) pair;
+                # stack the two halves into separate batch tensors.
+                inpt_img_batches_transformed = torch.vstack(
+                    [i[1] for i in inpt_img_batches]
+                )
+                inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])
+
+                inpt_trimaps_transformed = torch.vstack(
+                    [i[1] for i in inpt_trimaps_batches]
+                )
+                inpt_trimaps_batches = torch.vstack(
+                    [i[0] for i in inpt_trimaps_batches]
+                )
+
+                with torch.no_grad():
+                    inpt_img_batches = inpt_img_batches.to(self.device)
+                    inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
+                    inpt_img_batches_transformed = inpt_img_batches_transformed.to(
+                        self.device
+                    )
+                    inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)
+
+                    output = super(FBAMatting, self).__call__(
+                        inpt_img_batches,
+                        inpt_trimaps_batches,
+                        inpt_img_batches_transformed,
+                        inpt_trimaps_transformed,
+                    )
+                    output_cpu = output.cpu()
+                    # Drop the device-side tensors before the next batch.
+                    del (
+                        inpt_img_batches,
+                        inpt_trimaps_batches,
+                        inpt_img_batches_transformed,
+                        inpt_trimaps_transformed,
+                        output,
+                    )
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
+                    range(len(inpt_images)),
+                )
+                collect_masks += masks
+            return collect_masks
+
+

Ancestors

+
    +
  • FBA
  • +
  • torch.nn.modules.module.Module
  • +
+

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, trimap:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
trimap : PIL.Image.Image
+
Map with the area we need to refine
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, trimap: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        trimap (PIL.Image.Image): Map with the area we need to refine
+
+    Returns:
+        PIL.Image.Image: Segmentation mask
+
+    Raises:
+        ValueError: if the trimap is not in mode "L"
+
+    """
+    if trimap.mode != "L":
+        raise ValueError("Incorrect color mode for trimap")
+    pred = data.numpy().transpose((1, 2, 0))
+    # `interpolation` has to be passed by keyword: the third positional
+    # parameter of `cv2.resize` is `dst`, so a positional flag is never
+    # used as the interpolation mode.
+    pred = cv2.resize(
+        pred, trimap.size, interpolation=cv2.INTER_LANCZOS4
+    )[:, :, 0]
+    # noinspection PyTypeChecker
+    # Clean mask by removing all false predictions outside trimap and already known area
+    trimap_arr = np.array(trimap.copy())
+    pred[trimap_arr[:, :] == 0] = 0
+    pred[pred < 0.3] = 0
+    # Lanczos resampling can overshoot the input range; keep alpha in [0, 1].
+    pred = np.clip(pred, 0.0, 1.0)
+    return Image.fromarray(pred * 255).convert("L")
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data:Β Union[PIL.Image.Image,Β numpy.ndarray]) ‑>Β Tuple[torch.FloatTensor,Β torch.FloatTensor] +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : Union[PIL.Image.Image, np.ndarray]
+
input image
+
+

Returns

+
+
Tuple[torch.FloatTensor, torch.FloatTensor]
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(
+    self, data: Union[PIL.Image.Image, np.ndarray]
+) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    """
+    Convert an image or a trimap into the pair of tensors the network consumes.
+
+    Args:
+        data (Union[PIL.Image.Image, np.ndarray]): input image; mode "RGB" is
+            handled as a colour image, mode "L" as a trimap
+
+    Returns:
+        Tuple[torch.FloatTensor, torch.FloatTensor]: the raw tensor and its
+            normalised (image) or transformed (trimap) counterpart
+
+    """
+    # A lone image may keep its aspect ratio; in larger batches every
+    # image is resized to exactly the configured size.
+    scaled = data.copy()
+    if self.batch_size != 1:
+        scaled = scaled.resize(self.input_image_size, resample=3)
+    else:
+        scaled.thumbnail(self.input_image_size, resample=3)
+    # noinspection PyTypeChecker
+    pixels = np.array(scaled, dtype=np.float64) / 255.0  # values in [0, 1]
+    is_rgb = scaled.mode == "RGB"
+    if is_rgb:
+        # Reverse the channel order.
+        pixels = pixels[:, :, ::-1]
+    elif scaled.mode == "L":
+        # Two-channel binary trimap: channel 1 marks pixels equal to 1,
+        # channel 0 marks pixels equal to 0.
+        rows, cols = pixels.shape
+        two_channel = np.zeros((rows, cols, 2))
+        two_channel[pixels == 1, 1] = 1
+        two_channel[pixels == 0, 0] = 1
+        pixels = two_channel
+    else:
+        raise ValueError("Incorrect color mode for image")
+    # Round both dimensions up to the next multiple of 8.
+    height, width = pixels.shape[:2]
+    target_h = int(np.ceil(1.0 * height / 8) * 8)
+    target_w = int(np.ceil(1.0 * width / 8) * 8)
+    x_scale = cv2.resize(
+        pixels, (target_w, target_h), interpolation=cv2.INTER_LANCZOS4
+    )
+    image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
+    if is_rgb:
+        normalised = groupnorm_normalise_image(image_tensor.clone(), format="nchw")
+        return image_tensor, normalised
+    transformed = torch.from_numpy(trimap_transform(x_scale))
+    transformed = transformed.permute(2, 0, 1)[None, :, :, :].float()
+    return (image_tensor, transformed)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/index.html b/docs/api/carvekit/ml/wrap/index.html new file mode 100644 index 0000000..c2ad10b --- /dev/null +++ b/docs/api/carvekit/ml/wrap/index.html @@ -0,0 +1,112 @@ + + + + + + +carvekit.ml.wrap API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap

+
+
+
+
+

Sub-modules

+
+
carvekit.ml.wrap.basnet
+
+ +
+
carvekit.ml.wrap.cascadepsp
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.ml.wrap.deeplab_v3
+
+ +
+
carvekit.ml.wrap.fba_matting
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.ml.wrap.scene_classifier
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.ml.wrap.tracer_b7
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.ml.wrap.u2net
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.ml.wrap.yolov4
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/scene_classifier.html b/docs/api/carvekit/ml/wrap/scene_classifier.html new file mode 100644 index 0000000..14c52e6 --- /dev/null +++ b/docs/api/carvekit/ml/wrap/scene_classifier.html @@ -0,0 +1,459 @@ + + + + + + +carvekit.ml.wrap.scene_classifier API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.scene_classifier

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+
+import PIL.Image
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from typing import List, Union, Tuple
+from torch.autograd import Variable
+
+from carvekit.ml.files.models_loc import scene_classifier_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["SceneClassifier"]
+
+
+class SceneClassifier:
+    """
+    SceneClassifier model interface
+
+    Description:
+        Performs a primary analysis of the image in order to select the necessary method for removing the background.
+        The choice is made by classifying the scene type.
+
+        The output can be the following types:
+        - hard
+        - soft
+        - digital
+
+    """
+
+    def __init__(
+        self,
+        topk: int = 1,
+        device="cpu",
+        batch_size: int = 4,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the Scene Classifier.
+
+        Args:
+            topk: number of top classes to return
+            device: processing device
+            batch_size: the number of images that the neural network processes in one run
+            fp16: use fp16 precision
+
+        """
+        if model_path is None:
+            model_path = scene_classifier_pretrained()
+        self.topk = topk
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+        state_dict = torch.load(model_path, map_location=device)
+        self.model = state_dict["model"]
+        self.class_to_idx = state_dict["class_to_idx"]
+        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+        self.model.to(device)
+        self.model.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+
+        Returns:
+            Top-k class of scene type, probability of these classes
+
+        """
+        ps = F.softmax(data.float(), dim=0)
+        topk = ps.cpu().topk(self.topk)
+
+        probs, classes = (e.data.numpy().squeeze().tolist() for e in topk)
+        if isinstance(classes, int):
+            classes = [classes]
+            probs = [probs]
+        return list(map(lambda x: self.idx_to_class[x], classes)), probs
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> Tuple[List[str], List[float]]:
+        """
+        Passes input images through the neural network and returns class predictions.
+
+        Args:
+            images: input images
+
+        Returns:
+            Top-k class of scene type, probability of these classes for every passed image
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.model, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = Variable(batches).to(self.device)
+                    masks = self.model.forward(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks_cpu[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SceneClassifier +(topk: int = 1, device='cpu', batch_size: int = 4, fp16: bool = False, model_path: Union[str, pathlib.Path] = None) +
+
+

SceneClassifier model interface

+

Description

+

Performs a primary analysis of the image in order to select the necessary method for removing the background. +The choice is made by classifying the scene type.

+

The output can be the following types: +- hard +- soft +- digital

+

Initialize the Scene Classifier.

+

Args

+
+
topk
+
number of top classes to return
+
device
+
processing device
+
batch_size
+
the number of images that the neural network processes in one run
+
fp16
+
use fp16 precision
+
+
+ +Expand source code + +
class SceneClassifier:
+    """
+    SceneClassifier model interface
+
+    Description:
+        Performs a primary analysis of the image in order to select the necessary method for removing the background.
+        The choice is made by classifying the scene type.
+
+        The output can be the following types:
+        - hard
+        - soft
+        - digital
+
+    """
+
+    def __init__(
+        self,
+        topk: int = 1,
+        device="cpu",
+        batch_size: int = 4,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the Scene Classifier.
+
+        Args:
+            topk: number of top classes to return
+            device: processing device
+            batch_size: the number of images that the neural network processes in one run
+            fp16: use fp16 precision
+
+        """
+        if model_path is None:
+            model_path = scene_classifier_pretrained()
+        self.topk = topk
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+        state_dict = torch.load(model_path, map_location=device)
+        self.model = state_dict["model"]
+        self.class_to_idx = state_dict["class_to_idx"]
+        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+        self.model.to(device)
+        self.model.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+
+        Returns:
+            Top-k class of scene type, probability of these classes
+
+        """
+        ps = F.softmax(data.float(), dim=0)
+        topk = ps.cpu().topk(self.topk)
+
+        probs, classes = (e.data.numpy().squeeze().tolist() for e in topk)
+        if isinstance(classes, int):
+            classes = [classes]
+            probs = [probs]
+        return list(map(lambda x: self.idx_to_class[x], classes)), probs
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> Tuple[List[str], List[float]]:
+        """
+        Passes input images through the neural network and returns class predictions.
+
+        Args:
+            images: input images
+
+        Returns:
+            Top-k class of scene type, probability of these classes for every passed image
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.model, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = Variable(batches).to(self.device)
+                    masks = self.model.forward(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks_cpu[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+

Methods

+
+
+def data_postprocessing(self, data: torch.Tensor) ‑> Tuple[List[str], List[float]] +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data
+
output data from neural network
+
+

Returns

+

Top-k class of scene type, probability of these classes

+
+ +Expand source code + +
def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data: output data from neural network
+
+    Returns:
+        Top-k class of scene type, probability of these classes
+
+    """
+    ps = F.softmax(data.float(), dim=0)
+    topk = ps.cpu().topk(self.topk)
+
+    probs, classes = (e.data.numpy().squeeze().tolist() for e in topk)
+    if isinstance(classes, int):
+        classes = [classes]
+        probs = [probs]
+    return list(map(lambda x: self.idx_to_class[x], classes)), probs
+
+
+
+def data_preprocessing(self, data: PIL.Image.Image) ‑> torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data
+
input image
+
+

Returns

+

input for neural network

+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data: input image
+
+    Returns:
+        input for neural network
+
+    """
+
+    return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/tracer_b7.html b/docs/api/carvekit/ml/wrap/tracer_b7.html new file mode 100644 index 0000000..97365c5 --- /dev/null +++ b/docs/api/carvekit/ml/wrap/tracer_b7.html @@ -0,0 +1,492 @@ + + + + + + +carvekit.ml.wrap.tracer_b7 API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.tracer_b7

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+from typing import List, Union
+
+import PIL.Image
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
+from carvekit.ml.arch.tracerb7.tracer import TracerDecoder
+from carvekit.ml.files.models_loc import tracer_b7_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["TracerUniversalB7"]
+
+
+class TracerUniversalB7(TracerDecoder):
+    """TRACER B7 model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 640,
+        batch_size: int = 4,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the TRACER model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=640): input image size
+            batch_size (int, default=4): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision
+            model_path (Union[str, pathlib.Path], default=None): path to the model
+            .. note:: REDO
+        """
+        if model_path is None:
+            model_path = tracer_b7_pretrained()
+        super(TracerUniversalB7, self).__init__(
+            encoder=EfficientEncoderB7(),
+            rfb_channel=[32, 64, 128],
+            features_channels=[48, 80, 224, 640],
+        )
+
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Resize(self.input_image_size),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        self.to(device)
+        if load_pretrained:
+            # TODO remove edge detector from weights. It doesn't work well with this model!
+            self.load_state_dict(
+                torch.load(model_path, map_location=self.device), strict=False
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask
+
+        """
+        output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
+            np.uint8
+        )
+        output = output.squeeze(0)
+        mask = Image.fromarray(output).convert("L")
+        mask = mask.resize(original_image.size, resample=Image.BILINEAR)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = batches.to(self.device)
+                    masks = super(TracerDecoder, self).__call__(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(
+                        masks_cpu[x], converted_images[x]
+                    ),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class TracerUniversalB7 +(device='cpu', input_image_size: Union[List[int], int] = 640, batch_size: int = 4, load_pretrained: bool = True, fp16: bool = False, model_path: Union[str, pathlib.Path] = None) +
+
+

TRACER B7 model interface

+

Initialize the TRACER model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_image_size : Union[List[int], int], default=640
+
input image size
+
batch_size(int, default=4): the number of images that the neural network processes in one run
+
load_pretrained(bool, default=True): loading pretrained model
+
fp16 : bool, default=False
+
use fp16 precision
+
model_path : Union[str, pathlib.Path], default=None
+
path to the model
+
+
+

Note: REDO

+
+
+ +Expand source code + +
class TracerUniversalB7(TracerDecoder):
+    """TRACER B7 model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 640,
+        batch_size: int = 4,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the TRACER model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=640): input image size
+            batch_size (int, default=4): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision
+            model_path (Union[str, pathlib.Path], default=None): path to the model
+            .. note:: REDO
+        """
+        if model_path is None:
+            model_path = tracer_b7_pretrained()
+        super(TracerUniversalB7, self).__init__(
+            encoder=EfficientEncoderB7(),
+            rfb_channel=[32, 64, 128],
+            features_channels=[48, 80, 224, 640],
+        )
+
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Resize(self.input_image_size),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        self.to(device)
+        if load_pretrained:
+            # TODO remove edge detector from weights. It doesn't work well with this model!
+            self.load_state_dict(
+                torch.load(model_path, map_location=self.device), strict=False
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask
+
+        """
+        output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
+            np.uint8
+        )
+        output = output.squeeze(0)
+        mask = Image.fromarray(output).convert("L")
+        mask = mask.resize(original_image.size, resample=Image.BILINEAR)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = batches.to(self.device)
+                    masks = super(TracerDecoder, self).__call__(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(
+                        masks_cpu[x], converted_images[x]
+                    ),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+

Ancestors

+ +

Static methods

+
+
+def data_postprocessing(data: torch.Tensor, original_image: PIL.Image.Image) ‑> PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask
+
+    """
+    output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
+        np.uint8
+    )
+    output = output.squeeze(0)
+    mask = Image.fromarray(output).convert("L")
+    mask = mask.resize(original_image.size, resample=Image.BILINEAR)
+    return mask
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data: PIL.Image.Image) ‑> torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.FloatTensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.FloatTensor: input for neural network
+
+    """
+
+    return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/u2net.html b/docs/api/carvekit/ml/wrap/u2net.html new file mode 100644 index 0000000..705f732 --- /dev/null +++ b/docs/api/carvekit/ml/wrap/u2net.html @@ -0,0 +1,485 @@ + + + + + + +carvekit.ml.wrap.u2net API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.u2net

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+import warnings
+
+from typing import List, Union
+import PIL.Image
+import numpy as np
+import torch
+from PIL import Image
+
+from carvekit.ml.arch.u2net.u2net import U2NETArchitecture
+from carvekit.ml.files.models_loc import u2net_full_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["U2NET"]
+
+
+class U2NET(U2NETArchitecture):
+    """U^2-Net model interface"""
+
+    def __init__(
+        self,
+        layers_cfg="full",
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the U2NET model
+
+        Args:
+            layers_cfg: neural network layers configuration
+            device: processing device
+            input_image_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use fp16 precision // not supported at this moment.
+
+        """
+        super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1)
+        if fp16:
+            warnings.warn("FP16 is not supported at this moment for U2NET model")
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(u2net_full_pretrained(), map_location=self.device)
+            )
+
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size, resample=3)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=float)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+
+        Returns:
+            segmentation masks for the input images, as PIL.Image.Image instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches)
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class U2NET +(layers_cfg='full', device='cpu', input_image_size: Union[List[int], int] = 320, batch_size: int = 10, load_pretrained: bool = True, fp16: bool = False) +
+
+

U^2-Net model interface

+

Initialize the U2NET model

+

Args

+
+
layers_cfg
+
neural network layers configuration
+
device
+
processing device
+
input_image_size
+
input image size
+
batch_size
+
the number of images that the neural network processes in one run
+
load_pretrained
+
loading pretrained model
+
fp16
+
use fp16 precision // not supported at this moment.
+
+
+ +Expand source code + +
class U2NET(U2NETArchitecture):
+    """U^2-Net model interface
+
+    Adds image loading, preprocessing, batched inference and mask
+    postprocessing on top of the U2NETArchitecture network.
+    """
+
+    def __init__(
+        self,
+        layers_cfg="full",
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the U2NET model
+
+        Args:
+            layers_cfg: neural network layers configuration
+            device: processing device
+            input_image_size: input image size; an int gives a square size,
+                a list contributes its first two elements
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use fp16 precision // not supported at this moment.
+
+        """
+        super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1)
+        # fp16 only triggers this warning; it is not stored or used anywhere else in this class.
+        if fp16:
+            warnings.warn("FP16 is not supported at this moment for U2NET model")
+        self.device = device
+        self.batch_size = batch_size
+        # Normalise the requested size to a two-element sequence.
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            # NOTE(review): the weights always come from u2net_full_pretrained(), whatever
+            # layers_cfg is - confirm other configurations are meant to use load_pretrained=False.
+            self.load_state_dict(
+                torch.load(u2net_full_pretrained(), map_location=self.device)
+            )
+
+        # Switch the module to evaluation mode.
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        The image is resized to self.input_image_size, divided by its own
+        maximum value (when that is non-zero), normalised per channel and
+        returned channel-first with a leading batch axis of size 1.
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+        # resample=3 is Pillow's numeric code for the bicubic filter.
+        resized = data.resize(self.input_image_size, resample=3)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=float)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        # Scale by the image's own maximum; an all-zero image is left untouched to avoid dividing by zero.
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        # Per-channel (value - mean) / std; the constants look like the usual ImageNet statistics - TODO confirm.
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        # Move the channel axis to the front, then add a leading batch axis.
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        The prediction is min-max normalised, scaled to 0..255, converted to
+        an "L" mode image and resized to the size of the original image.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        # Add a leading axis, then take index 0 of the second axis (so data is indexed as a 3-D tensor here).
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        # NOTE(review): a constant mask gives ma == mi and therefore a division by zero here - confirm whether that can occur.
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        # Bring the mask back to the size of the source image (resample=3 is Pillow's bicubic code).
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images (paths or PIL images; each one goes through load_image and convert_image)
+
+        Returns:
+            segmentation masks for the input images, as PIL.Image.Image instances
+
+        """
+        collect_masks = []
+        # Work through the inputs in the chunks produced by batch_generator(images, self.batch_size).
+        for image_batch in batch_generator(images, self.batch_size):
+            # Load and convert every image of the current chunk.
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            # Each preprocessed tensor has a leading axis of size 1; vstack joins them into one batch.
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            # Inference only: gradient tracking is switched off.
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                # The parent call yields seven outputs; only the first one is kept as the mask.
+                masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches)
+                masks_cpu = masks.cpu()
+                # Drop the references to the remaining outputs and to the input batch.
+                del d2, d3, d4, d5, d6, d7, batches, masks
+            # Convert every predicted mask into a PIL image, pairing it with its source image by index.
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+

Ancestors

+ +

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, original_image:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask as PIL Image instance
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    The prediction is min-max normalised, scaled to 0..255, converted to
+    an "L" mode image and resized to the size of the original image.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+    """
+    # Add a leading axis, then take index 0 of the second axis (so data is indexed as a 3-D tensor here).
+    data = data.unsqueeze(0)
+    mask = data[:, 0, :, :]
+    ma = torch.max(mask)  # Normalizes prediction
+    mi = torch.min(mask)
+    # NOTE(review): a constant mask gives ma == mi and therefore a division by zero here - confirm whether that can occur.
+    predict = ((mask - mi) / (ma - mi)).squeeze()
+    predict_np = predict.cpu().data.numpy() * 255
+    mask = Image.fromarray(predict_np).convert("L")
+    # Bring the mask back to the size of the source image (resample=3 is Pillow's bicubic code).
+    mask = mask.resize(original_image.size, resample=3)
+    return mask
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.FloatTensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    The image is resized to self.input_image_size, divided by its own
+    maximum value (when that is non-zero), normalised per channel and
+    returned channel-first with a leading batch axis of size 1.
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.FloatTensor: input for neural network
+
+    """
+    # resample=3 is Pillow's numeric code for the bicubic filter.
+    resized = data.resize(self.input_image_size, resample=3)
+    # noinspection PyTypeChecker
+    resized_arr = np.array(resized, dtype=float)
+    temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+    # Scale by the image's own maximum; an all-zero image is left untouched to avoid dividing by zero.
+    if np.max(resized_arr) != 0:
+        resized_arr /= np.max(resized_arr)
+    # Per-channel (value - mean) / std; the constants look like the usual ImageNet statistics - TODO confirm.
+    temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+    temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+    temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+    # Move the channel axis to the front, then add a leading batch axis.
+    temp_image = temp_image.transpose((2, 0, 1))
+    temp_image = np.expand_dims(temp_image, 0)
+    return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/ml/wrap/yolov4.html b/docs/api/carvekit/ml/wrap/yolov4.html new file mode 100644 index 0000000..25d9a6a --- /dev/null +++ b/docs/api/carvekit/ml/wrap/yolov4.html @@ -0,0 +1,881 @@ + + + + + + +carvekit.ml.wrap.yolov4 API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.ml.wrap.yolov4

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+
+import pathlib
+
+import PIL.Image
+import PIL.Image
+import numpy as np
+import pydantic
+import torch
+from torch.autograd import Variable
+from typing import List, Union
+
+from carvekit.ml.arch.yolov4.models import Yolov4
+from carvekit.ml.arch.yolov4.utils import post_processing
+from carvekit.ml.files.models_loc import yolov4_coco_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["YoloV4_COCO", "SimplifiedYoloV4"]
+
+
+class Object(pydantic.BaseModel):
+    """Object class
+
+    One detected object: its class name, confidence and bounding box.
+    """
+
+    # Name taken from the detector's class list.
+    class_name: str
+    # Confidence value reported for the detection.
+    confidence: float
+    # Bounding box corners, filled with integer pixel coordinates by YoloV4_COCO.data_postprocessing.
+    x1: int
+    y1: int
+    x2: int
+    y2: int
+
+
+class YoloV4_COCO(Yolov4):
+    """YoloV4 COCO model wrapper
+
+    Adds image loading, preprocessing, batched inference and conversion of
+    the network output into `Object` instances on top of `Yolov4`.
+    """
+
+    def __init__(
+        self,
+        n_classes: int = 80,
+        device="cpu",
+        classes: List[str] = None,
+        input_image_size: Union[List[int], int] = 608,
+        batch_size: int = 4,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the YoloV4 COCO.
+
+        Args:
+            n_classes: number of classes (only used when load_pretrained is False)
+            device: processing device
+            classes: class names (only used when load_pretrained is False)
+            input_image_size: input image size; an int gives a square size,
+                a list contributes its first two elements
+            batch_size: the number of images that the neural network processes in one run
+            fp16: use fp16 precision
+            model_path: path to model weights; defaults to yolov4_coco_pretrained()
+            load_pretrained: load pretrained weights
+        """
+        # NOTE(review): yolov4_coco_pretrained() is called even when load_pretrained is False.
+        if model_path is None:
+            model_path = yolov4_coco_pretrained()
+        # These attributes are assigned before the parent constructor runs (it is called below).
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        # Normalise the requested size to a two-element sequence.
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+
+        if load_pretrained:
+            # The checkpoint is read on the CPU and must provide the keys "classes" and "state".
+            state_dict = torch.load(model_path, map_location="cpu")
+            self.classes = state_dict["classes"]
+            # The class count comes from the checkpoint, not from the n_classes argument.
+            super().__init__(n_classes=len(state_dict["classes"]), inference=True)
+            self.load_state_dict(state_dict["state"])
+        else:
+            self.classes = classes
+            super().__init__(n_classes=n_classes, inference=True)
+
+        self.to(device)
+        # Switch the module to evaluation mode.
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        The image is resized to self.input_image_size, converted to float32,
+        reordered channel-first, divided by 255 and given a leading batch
+        axis of size 1.
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+        image = data.resize(self.input_image_size)
+        # noinspection PyTypeChecker
+        image = np.array(image).astype(np.float32)
+        # Move the channel axis to the front.
+        image = image.transpose((2, 0, 1))
+        image = image / 255.0
+        image = torch.from_numpy(image).float()
+        # Add the leading batch axis.
+        return torch.unsqueeze(image, 0).type(torch.FloatTensor)
+
+    def data_postprocessing(
+        self, data: List[torch.FloatTensor], images: List[PIL.Image.Image]
+    ) -> List[Object]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            images: input images
+
+
+        Returns:
+            list of objects for each image
+
+        """
+        # NOTE(review): the return annotation says List[Object], but one list of Object is built per image.
+        # 0.4 and 0.6 are presumably the confidence and NMS thresholds of post_processing - TODO confirm.
+        output = post_processing(0.4, 0.6, data)
+        images_objects = []
+        for image_idx, image_objects in enumerate(output):
+            image_size = images[image_idx].size
+            objects = []
+            for obj in image_objects:
+                # obj[0..3] are multiplied by the image width/height, so they are presumably
+                # box corners normalised to the image size; obj[5] is used as the confidence
+                # and obj[6] as the index into self.classes.
+                objects.append(
+                    Object(
+                        class_name=self.classes[obj[6]],
+                        confidence=obj[5],
+                        x1=int(obj[0] * image_size[0]),
+                        y1=int(obj[1] * image_size[1]),
+                        x2=int(obj[2] * image_size[0]),
+                        y2=int(obj[3] * image_size[1]),
+                    )
+                )
+            images_objects.append(objects)
+
+        return images_objects
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[List[Object]]:
+        """
+        Passes input images through the neural network
+
+        Args:
+            images: input images
+
+        Returns:
+            list of objects for each image
+
+        """
+        # Despite its name, this list collects the per-image results of data_postprocessing.
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            # Work through the inputs in the chunks produced by batch_generator(images, self.batch_size).
+            for image_batch in batch_generator(images, self.batch_size):
+                # Load and convert every image of the current chunk.
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                # Each preprocessed tensor has a leading axis of size 1; vstack joins them into one batch.
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                # Inference only: gradient tracking is switched off.
+                with torch.no_grad():
+                    batches = Variable(batches).to(self.device)
+                    out = super().__call__(batches)
+                    # Move every output tensor to the CPU before postprocessing.
+                    out_cpu = [out_i.cpu() for out_i in out]
+                    del batches, out
+                out = self.data_postprocessing(out_cpu, converted_images)
+                collect_masks += out
+
+        return collect_masks
+
+
+class SimplifiedYoloV4(YoloV4_COCO):
+    """
+    The YoloV4 COCO classifier, but classifies only 5 supercategories.
+
+    human - Scenes of people, such as portrait photographs
+    animals - Scenes with animals
+    objects - Scenes with normal objects
+    cars - Scenes with cars
+    other - Other scenes
+    """
+
+    # Maps each supercategory name to the detector class names that belong to it.
+    db = {
+        "human": ["person"],
+        "animals": [
+            "bird",
+            "cat",
+            "dog",
+            "horse",
+            "sheep",
+            "cow",
+            "elephant",
+            "bear",
+            "zebra",
+            "giraffe",
+        ],
+        "cars": [
+            "car",
+            "motorbike",
+            "bus",
+            "truck",
+        ],
+        "objects": [
+            "bicycle",
+            "traffic light",
+            "fire hydrant",
+            "stop sign",
+            "parking meter",
+            "bench",
+            "backpack",
+            "umbrella",
+            "handbag",
+            "tie",
+            "suitcase",
+            "frisbee",
+            "skis",
+            "snowboard",
+            "sports ball",
+            "kite",
+            "baseball bat",
+            "baseball glove",
+            "skateboard",
+            "surfboard",
+            "tennis racket",
+            "bottle",
+            "wine glass",
+            "cup",
+            "fork",
+            "knife",
+            "spoon",
+            "bowl",
+            "banana",
+            "apple",
+            "sandwich",
+            "orange",
+            "broccoli",
+            "carrot",
+            "hot dog",
+            "pizza",
+            "donut",
+            "cake",
+            "chair",
+            "sofa",
+            "pottedplant",
+            "bed",
+            "diningtable",
+            "toilet",
+            "tvmonitor",
+            "laptop",
+            "mouse",
+            "remote",
+            "keyboard",
+            "cell phone",
+            "microwave",
+            "oven",
+            "toaster",
+            "sink",
+            "refrigerator",
+            "book",
+            "clock",
+            "vase",
+            "scissors",
+            "teddy bear",
+            "hair drier",
+            "toothbrush",
+        ],
+        "other": ["aeroplane", "train", "boat"],
+    }
+
+    def data_postprocessing(
+        self, data: List[torch.FloatTensor], images: List[PIL.Image.Image]
+    ) -> List[List[str]]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            images: input images
+
+        Returns:
+            for each image, the supercategory names (keys of `db`) of its
+            detected objects; one entry per detection whose class name is
+            listed in `db`, so a name can appear several times
+        """
+        # Reuse the parent implementation to get the detections of every image.
+        objects = super().data_postprocessing(data, images)
+        new_output = []
+
+        for image_objects in objects:
+            new_objects = []
+            for obj in image_objects:
+                # Look the class name up in every supercategory; names missing from db add nothing.
+                for key, values in list(self.db.items()):
+                    if obj.class_name in values:
+                        new_objects.append(key)  # We don't need bbox at this moment
+            new_output.append(new_objects)
+
+        return new_output
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SimplifiedYoloV4 +(n_classes:Β intΒ =Β 80, device='cpu', classes:Β List[str]Β =Β None, input_image_size:Β Union[List[int],Β int]Β =Β 608, batch_size:Β intΒ =Β 4, load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False, model_path:Β Union[str,Β pathlib.Path]Β =Β None) +
+
+

The YoloV4 COCO classifier, but classifies only 5 supercategories.

+

human - Scenes of people, such as portrait photographs +animals - Scenes with animals +objects - Scenes with normal objects +cars - Scenes with cars +other - Other scenes

+

Initialize the YoloV4 COCO.

+

Args

+
+
n_classes
+
number of classes
+
device
+
processing device
+
input_image_size
+
input image size
+
batch_size
+
the number of images that the neural network processes in one run
+
fp16
+
use fp16 precision
+
model_path
+
path to model weights
+
load_pretrained
+
load pretrained weights
+
+
+ +Expand source code + +
class SimplifiedYoloV4(YoloV4_COCO):
+    """
+    The YoloV4 COCO classifier, but classifies only 5 supercategories.
+
+    human - Scenes of people, such as portrait photographs
+    animals - Scenes with animals
+    objects - Scenes with normal objects
+    cars - Scenes with cars
+    other - Other scenes
+    """
+
+    # Maps each supercategory name to the detector class names that belong to it.
+    db = {
+        "human": ["person"],
+        "animals": [
+            "bird",
+            "cat",
+            "dog",
+            "horse",
+            "sheep",
+            "cow",
+            "elephant",
+            "bear",
+            "zebra",
+            "giraffe",
+        ],
+        "cars": [
+            "car",
+            "motorbike",
+            "bus",
+            "truck",
+        ],
+        "objects": [
+            "bicycle",
+            "traffic light",
+            "fire hydrant",
+            "stop sign",
+            "parking meter",
+            "bench",
+            "backpack",
+            "umbrella",
+            "handbag",
+            "tie",
+            "suitcase",
+            "frisbee",
+            "skis",
+            "snowboard",
+            "sports ball",
+            "kite",
+            "baseball bat",
+            "baseball glove",
+            "skateboard",
+            "surfboard",
+            "tennis racket",
+            "bottle",
+            "wine glass",
+            "cup",
+            "fork",
+            "knife",
+            "spoon",
+            "bowl",
+            "banana",
+            "apple",
+            "sandwich",
+            "orange",
+            "broccoli",
+            "carrot",
+            "hot dog",
+            "pizza",
+            "donut",
+            "cake",
+            "chair",
+            "sofa",
+            "pottedplant",
+            "bed",
+            "diningtable",
+            "toilet",
+            "tvmonitor",
+            "laptop",
+            "mouse",
+            "remote",
+            "keyboard",
+            "cell phone",
+            "microwave",
+            "oven",
+            "toaster",
+            "sink",
+            "refrigerator",
+            "book",
+            "clock",
+            "vase",
+            "scissors",
+            "teddy bear",
+            "hair drier",
+            "toothbrush",
+        ],
+        "other": ["aeroplane", "train", "boat"],
+    }
+
+    def data_postprocessing(
+        self, data: List[torch.FloatTensor], images: List[PIL.Image.Image]
+    ) -> List[List[str]]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            images: input images
+
+        Returns:
+            for each image, the supercategory names (keys of `db`) of its
+            detected objects; one entry per detection whose class name is
+            listed in `db`, so a name can appear several times
+        """
+        # Reuse the parent implementation to get the detections of every image.
+        objects = super().data_postprocessing(data, images)
+        new_output = []
+
+        for image_objects in objects:
+            new_objects = []
+            for obj in image_objects:
+                # Look the class name up in every supercategory; names missing from db add nothing.
+                for key, values in list(self.db.items()):
+                    if obj.class_name in values:
+                        new_objects.append(key)  # We don't need bbox at this moment
+            new_output.append(new_objects)
+
+        return new_output
+
+

Ancestors

+ +

Class variables

+
+
var db
+
+
+
+
+

Methods

+
+
+def data_postprocessing(self, data:Β List[torch.FloatTensor], images:Β List[PIL.Image.Image]) ‑>Β List[List[str]] +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data
+
output data from neural network
+
images
+
input images
+
+
+ +Expand source code + +
def data_postprocessing(
+    self, data: List[torch.FloatTensor], images: List[PIL.Image.Image]
+) -> List[List[str]]:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data: output data from neural network
+        images: input images
+
+    Returns:
+        for each image, the supercategory names (keys of `db`) of its
+        detected objects; one entry per detection whose class name is
+        listed in `db`, so a name can appear several times
+    """
+    # Reuse the parent implementation to get the detections of every image.
+    objects = super().data_postprocessing(data, images)
+    new_output = []
+
+    for image_objects in objects:
+        new_objects = []
+        for obj in image_objects:
+            # Look the class name up in every supercategory; names missing from db add nothing.
+            for key, values in list(self.db.items()):
+                if obj.class_name in values:
+                    new_objects.append(key)  # We don't need bbox at this moment
+        new_output.append(new_objects)
+
+    return new_output
+
+
+
+

Inherited members

+ +
+
+class YoloV4_COCO +(n_classes:Β intΒ =Β 80, device='cpu', classes:Β List[str]Β =Β None, input_image_size:Β Union[List[int],Β int]Β =Β 608, batch_size:Β intΒ =Β 4, load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False, model_path:Β Union[str,Β pathlib.Path]Β =Β None) +
+
+

YoloV4 COCO model wrapper

+

Initialize the YoloV4 COCO.

+

Args

+
+
n_classes
+
number of classes
+
device
+
processing device
+
input_image_size
+
input image size
+
batch_size
+
the number of images that the neural network processes in one run
+
fp16
+
use fp16 precision
+
model_path
+
path to model weights
+
load_pretrained
+
load pretrained weights
+
+
+ +Expand source code + +
class YoloV4_COCO(Yolov4):
+    """YoloV4 COCO model wrapper
+
+    Adds image loading, preprocessing, batched inference and conversion of
+    the network output into `Object` instances on top of `Yolov4`.
+    """
+
+    def __init__(
+        self,
+        n_classes: int = 80,
+        device="cpu",
+        classes: List[str] = None,
+        input_image_size: Union[List[int], int] = 608,
+        batch_size: int = 4,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the YoloV4 COCO.
+
+        Args:
+            n_classes: number of classes (only used when load_pretrained is False)
+            device: processing device
+            classes: class names (only used when load_pretrained is False)
+            input_image_size: input image size; an int gives a square size,
+                a list contributes its first two elements
+            batch_size: the number of images that the neural network processes in one run
+            fp16: use fp16 precision
+            model_path: path to model weights; defaults to yolov4_coco_pretrained()
+            load_pretrained: load pretrained weights
+        """
+        # NOTE(review): yolov4_coco_pretrained() is called even when load_pretrained is False.
+        if model_path is None:
+            model_path = yolov4_coco_pretrained()
+        # These attributes are assigned before the parent constructor runs (it is called below).
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        # Normalise the requested size to a two-element sequence.
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+
+        if load_pretrained:
+            # The checkpoint is read on the CPU and must provide the keys "classes" and "state".
+            state_dict = torch.load(model_path, map_location="cpu")
+            self.classes = state_dict["classes"]
+            # The class count comes from the checkpoint, not from the n_classes argument.
+            super().__init__(n_classes=len(state_dict["classes"]), inference=True)
+            self.load_state_dict(state_dict["state"])
+        else:
+            self.classes = classes
+            super().__init__(n_classes=n_classes, inference=True)
+
+        self.to(device)
+        # Switch the module to evaluation mode.
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        The image is resized to self.input_image_size, converted to float32,
+        reordered channel-first, divided by 255 and given a leading batch
+        axis of size 1.
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+        image = data.resize(self.input_image_size)
+        # noinspection PyTypeChecker
+        image = np.array(image).astype(np.float32)
+        # Move the channel axis to the front.
+        image = image.transpose((2, 0, 1))
+        image = image / 255.0
+        image = torch.from_numpy(image).float()
+        # Add the leading batch axis.
+        return torch.unsqueeze(image, 0).type(torch.FloatTensor)
+
+    def data_postprocessing(
+        self, data: List[torch.FloatTensor], images: List[PIL.Image.Image]
+    ) -> List[Object]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            images: input images
+
+
+        Returns:
+            list of objects for each image
+
+        """
+        # NOTE(review): the return annotation says List[Object], but one list of Object is built per image.
+        # 0.4 and 0.6 are presumably the confidence and NMS thresholds of post_processing - TODO confirm.
+        output = post_processing(0.4, 0.6, data)
+        images_objects = []
+        for image_idx, image_objects in enumerate(output):
+            image_size = images[image_idx].size
+            objects = []
+            for obj in image_objects:
+                # obj[0..3] are multiplied by the image width/height, so they are presumably
+                # box corners normalised to the image size; obj[5] is used as the confidence
+                # and obj[6] as the index into self.classes.
+                objects.append(
+                    Object(
+                        class_name=self.classes[obj[6]],
+                        confidence=obj[5],
+                        x1=int(obj[0] * image_size[0]),
+                        y1=int(obj[1] * image_size[1]),
+                        x2=int(obj[2] * image_size[0]),
+                        y2=int(obj[3] * image_size[1]),
+                    )
+                )
+            images_objects.append(objects)
+
+        return images_objects
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[List[Object]]:
+        """
+        Passes input images through the neural network
+
+        Args:
+            images: input images
+
+        Returns:
+            list of objects for each image
+
+        """
+        # Despite its name, this list collects the per-image results of data_postprocessing.
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            # Work through the inputs in the chunks produced by batch_generator(images, self.batch_size).
+            for image_batch in batch_generator(images, self.batch_size):
+                # Load and convert every image of the current chunk.
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                # Each preprocessed tensor has a leading axis of size 1; vstack joins them into one batch.
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                # Inference only: gradient tracking is switched off.
+                with torch.no_grad():
+                    batches = Variable(batches).to(self.device)
+                    out = super().__call__(batches)
+                    # Move every output tensor to the CPU before postprocessing.
+                    out_cpu = [out_i.cpu() for out_i in out]
+                    del batches, out
+                out = self.data_postprocessing(out_cpu, converted_images)
+                collect_masks += out
+
+        return collect_masks
+
+

Ancestors

+
    +
  • Yolov4
  • +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Methods

+
+
+def data_postprocessing(self, data:Β List[torch.FloatTensor], images:Β List[PIL.Image.Image]) ‑>Β List[carvekit.ml.wrap.yolov4.Object] +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data
+
output data from neural network
+
images
+
input images
+
+

Returns

+

list of objects for each image

+
+ +Expand source code + +
def data_postprocessing(
+    self, data: List[torch.FloatTensor], images: List[PIL.Image.Image]
+) -> List[Object]:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data: output data from neural network
+        images: input images
+
+
+    Returns:
+        list of objects for each image
+
+    """
+    # NOTE(review): the return annotation says List[Object], but one list of Object is built per image.
+    # 0.4 and 0.6 are presumably the confidence and NMS thresholds of post_processing - TODO confirm.
+    output = post_processing(0.4, 0.6, data)
+    images_objects = []
+    for image_idx, image_objects in enumerate(output):
+        image_size = images[image_idx].size
+        objects = []
+        for obj in image_objects:
+            # obj[0..3] are multiplied by the image width/height, so they are presumably
+            # box corners normalised to the image size; obj[5] is used as the confidence
+            # and obj[6] as the index into self.classes.
+            objects.append(
+                Object(
+                    class_name=self.classes[obj[6]],
+                    confidence=obj[5],
+                    x1=int(obj[0] * image_size[0]),
+                    y1=int(obj[1] * image_size[1]),
+                    x2=int(obj[2] * image_size[0]),
+                    y2=int(obj[3] * image_size[1]),
+                )
+            )
+        images_objects.append(objects)
+
+    return images_objects
+
+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data
+
input image
+
+

Returns

+

input for neural network

+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    The image is resized to self.input_image_size, converted to float32,
+    reordered channel-first, divided by 255 and given a leading batch
+    axis of size 1.
+
+    Args:
+        data: input image
+
+    Returns:
+        input for neural network
+
+    """
+    image = data.resize(self.input_image_size)
+    # noinspection PyTypeChecker
+    image = np.array(image).astype(np.float32)
+    # Move the channel axis to the front.
+    image = image.transpose((2, 0, 1))
+    image = image / 255.0
+    image = torch.from_numpy(image).float()
+    # Add the leading batch axis.
+    return torch.unsqueeze(image, 0).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/index.html b/docs/api/carvekit/pipelines/index.html new file mode 100644 index 0000000..52e92a9 --- /dev/null +++ b/docs/api/carvekit/pipelines/index.html @@ -0,0 +1,70 @@ + + + + + + +carvekit.pipelines API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines

+
+
+
+
+

Sub-modules

+
+
carvekit.pipelines.postprocessing
+
+
+
+
carvekit.pipelines.preprocessing
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/postprocessing.html b/docs/api/carvekit/pipelines/postprocessing.html new file mode 100644 index 0000000..354fb9a --- /dev/null +++ b/docs/api/carvekit/pipelines/postprocessing.html @@ -0,0 +1,226 @@ + + + + + + +carvekit.pipelines.postprocessing API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.postprocessing

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from typing import Union, List
+from PIL import Image
+from pathlib import Path
+from carvekit.trimap.cv_gen import CV2TrimapGenerator
+from carvekit.trimap.generator import TrimapGenerator
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+from carvekit.utils.image_utils import load_image, convert_image
+
+__all__ = ["MattingMethod"]
+
+
+class MattingMethod:
+    """
+    Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap.
+    Neural network for matting performs accurate object edge detection by using a special map called trimap,
+    with unknown area that we scan for boundary, already known general object area and the background."""
+
+    def __init__(
+        self,
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes Matting Method class.
+
+        Args:
+        - `matting_module`: Initialized matting neural network class
+        - `trimap_generator`: Initialized trimap generator class
+        - `device`: Processing device used for applying mask to image
+        """
+        # The arguments are only stored here; they are used when the instance is called.
+        self.device = device
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Generates a trimap for every image/mask pair, passes images and trimaps
+        through the matting module and applies the resulting alpha to each image
+        with the apply_mask function
+
+        Args:
+        - `images`: list of images
+        - `masks`: list of masks
+
+        Returns:
+        list of images
+
+        Raises:
+        ValueError if `images` and `masks` have different lengths
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        # Load and convert the inputs; masks are converted with mode="L".
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        # Build one trimap per image/mask pair, matched by index.
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
+            range(len(images)),
+        )
+        # The matting module receives all images and trimaps in a single call.
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        # Apply each alpha to its image, one after another.
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MattingMethod +(matting_module:Β FBAMatting, trimap_generator:Β Union[TrimapGenerator,Β CV2TrimapGenerator], device='cpu') +
+
+

Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap. +Neural network for matting performs accurate object edge detection by using a special map called trimap, +with unknown area that we scan for boundary, already known general object area and the background.

+

Initializes Matting Method class.

+

Args: +- matting_module: Initialized matting neural network class +- trimap_generator: Initialized trimap generator class +- device: Processing device used for applying mask to image

+
+ +Expand source code + +
class MattingMethod:
+    """
+    Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap.
+    Neural network for matting performs accurate object edge detection by using a special map called trimap,
+    with unknown area that we scan for boundary, already known general object area and the background."""
+
+    def __init__(
+        self,
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes Matting Method class.
+
+        Args:
+        - `matting_module`: Initialized matting neural network class
+        - `trimap_generator`: Initialized trimap generator class
+        - `device`: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+        - `images`: list of images
+        - `masks`: list of masks
+
+        Returns:
+        list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/postprocessing/casmatting.html b/docs/api/carvekit/pipelines/postprocessing/casmatting.html new file mode 100644 index 0000000..16732cc --- /dev/null +++ b/docs/api/carvekit/pipelines/postprocessing/casmatting.html @@ -0,0 +1,245 @@ + + + + + + +carvekit.pipelines.postprocessing.casmatting API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.postprocessing.casmatting

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from carvekit.ml.wrap.cascadepsp import CascadePSP
+from typing import Union, List
+from PIL import Image
+from pathlib import Path
+from carvekit.trimap.cv_gen import CV2TrimapGenerator
+from carvekit.trimap.generator import TrimapGenerator
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+from carvekit.utils.image_utils import load_image, convert_image
+
+__all__ = ["CasMattingMethod"]
+
+
+class CasMattingMethod:
+    """
+    Improve segmentation quality by refining segmentation with the CascadePSP model
+    and post-processing the segmentation with the FBAMatting model
+    """
+
+    def __init__(
+        self,
+        refining_module: Union[CascadePSP],
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes CasMattingMethod class.
+
+        Args:
+            refining_module: Initialized refining network
+            matting_module: Initialized matting neural network class
+            trimap_generator: Initialized trimap generator class
+            device: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.refining_module = refining_module
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+            images: list of images
+            masks: list of masks
+
+        Returns:
+            list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        refined_masks = self.refining_module(images, masks)
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(
+                original_image=images[x], mask=refined_masks[x]
+            ),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CasMattingMethod +(refining_module: CascadePSP, matting_module: FBAMatting, trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator], device='cpu') +
+
+

Improve segmentation quality by refining segmentation with the CascadePSP model +and post-processing the segmentation with the FBAMatting model

+

Initializes CasMattingMethod class.

+

Args

+
+
refining_module
+
Initialized refining network
+
matting_module
+
Initialized matting neural network class
+
trimap_generator
+
Initialized trimap generator class
+
device
+
Processing device used for applying mask to image
+
+
+ +Expand source code + +
class CasMattingMethod:
+    """
+    Improve segmentation quality by refining segmentation with the CascadePSP model
+    and post-processing the segmentation with the FBAMatting model
+    """
+
+    def __init__(
+        self,
+        refining_module: Union[CascadePSP],
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes CasMattingMethod class.
+
+        Args:
+            refining_module: Initialized refining network
+            matting_module: Initialized matting neural network class
+            trimap_generator: Initialized trimap generator class
+            device: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.refining_module = refining_module
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+            images: list of images
+            masks: list of masks
+
+        Returns:
+            list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        refined_masks = self.refining_module(images, masks)
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(
+                original_image=images[x], mask=refined_masks[x]
+            ),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/postprocessing/index.html b/docs/api/carvekit/pipelines/postprocessing/index.html new file mode 100644 index 0000000..18d0b62 --- /dev/null +++ b/docs/api/carvekit/pipelines/postprocessing/index.html @@ -0,0 +1,81 @@ + + + + + + +carvekit.pipelines.postprocessing API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.postprocessing

+
+
+
+ +Expand source code + +
from carvekit.pipelines.postprocessing.matting import MattingMethod
+from carvekit.pipelines.postprocessing.casmatting import CasMattingMethod
+
+
+
+

Sub-modules

+
+
carvekit.pipelines.postprocessing.casmatting
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.pipelines.postprocessing.matting
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/postprocessing/matting.html b/docs/api/carvekit/pipelines/postprocessing/matting.html new file mode 100644 index 0000000..677908e --- /dev/null +++ b/docs/api/carvekit/pipelines/postprocessing/matting.html @@ -0,0 +1,226 @@ + + + + + + +carvekit.pipelines.postprocessing.matting API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.postprocessing.matting

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from typing import Union, List
+from PIL import Image
+from pathlib import Path
+from carvekit.trimap.cv_gen import CV2TrimapGenerator
+from carvekit.trimap.generator import TrimapGenerator
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+from carvekit.utils.image_utils import load_image, convert_image
+
+__all__ = ["MattingMethod"]
+
+
+class MattingMethod:
+    """
+    Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap.
+    Neural network for matting performs accurate object edge detection by using a special map called trimap,
+    with unknown area that we scan for boundary, already known general object area and the background."""
+
+    def __init__(
+        self,
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes Matting Method class.
+
+        Args:
+        - `matting_module`: Initialized matting neural network class
+        - `trimap_generator`: Initialized trimap generator class
+        - `device`: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+        - `images`: list of images
+        - `masks`: list of masks
+
+        Returns:
+        list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MattingMethod +(matting_module: FBAMatting, trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator], device='cpu') +
+
+

Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap. +Neural network for matting performs accurate object edge detection by using a special map called trimap, +with unknown area that we scan for boundary, already known general object area and the background.

+

Initializes Matting Method class.

+

Args: +- matting_module: Initialized matting neural network class +- trimap_generator: Initialized trimap generator class +- device: Processing device used for applying mask to image

+
+ +Expand source code + +
class MattingMethod:
+    """
+    Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap.
+    Neural network for matting performs accurate object edge detection by using a special map called trimap,
+    with unknown area that we scan for boundary, already known general object area and the background."""
+
+    def __init__(
+        self,
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes Matting Method class.
+
+        Args:
+        - `matting_module`: Initialized matting neural network class
+        - `trimap_generator`: Initialized trimap generator class
+        - `device`: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+        - `images`: list of images
+        - `masks`: list of masks
+
+        Returns:
+        list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/preprocessing.html b/docs/api/carvekit/pipelines/preprocessing.html new file mode 100644 index 0000000..5e6f830 --- /dev/null +++ b/docs/api/carvekit/pipelines/preprocessing.html @@ -0,0 +1,127 @@ + + + + + + +carvekit.pipelines.preprocessing API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.preprocessing

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+from typing import Union, List
+
+from PIL import Image
+
+__all__ = ["PreprocessingStub"]
+
+
+class PreprocessingStub:
+    """Stub for future preprocessing methods"""
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Passes data through `interface.segmentation_pipeline()` method
+
+        Args:
+        - `interface`: Interface instance
+        - `images`: list of images
+
+        Returns:
+            the result of passing data through segmentation_pipeline method of interface
+        """
+        return interface.segmentation_pipeline(images=images)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class PreprocessingStub +
+
+

Stub for future preprocessing methods

+
+ +Expand source code + +
class PreprocessingStub:
+    """Stub for future preprocessing methods"""
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Passes data through `interface.segmentation_pipeline()` method
+
+        Args:
+        - `interface`: Interface instance
+        - `images`: list of images
+
+        Returns:
+            the result of passing data through segmentation_pipeline method of interface
+        """
+        return interface.segmentation_pipeline(images=images)
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/preprocessing/autoscene.html b/docs/api/carvekit/pipelines/preprocessing/autoscene.html new file mode 100644 index 0000000..cddfecd --- /dev/null +++ b/docs/api/carvekit/pipelines/preprocessing/autoscene.html @@ -0,0 +1,279 @@ + + + + + + +carvekit.pipelines.preprocessing.autoscene API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.preprocessing.autoscene

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+
+from PIL import Image
+from typing import Union, List
+
+from carvekit.ml.wrap.scene_classifier import SceneClassifier
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.ml.wrap.u2net import U2NET
+
+__all__ = ["AutoScene"]
+
+
+class AutoScene:
+    """AutoScene preprocessing method"""
+
+    def __init__(self, scene_classifier: SceneClassifier):
+        """
+        Args:
+            scene_classifier: SceneClassifier instance
+        """
+        self.scene_classifier = scene_classifier
+
+    @staticmethod
+    def select_net(scene: str):
+        """
+        Selects the network to be used for segmentation based on the detected scene
+
+        Args:
+            scene: scene name
+        """
+        if scene == "hard":
+            return TracerUniversalB7
+        elif scene == "soft":
+            return U2NET
+        elif scene == "digital":
+            return TracerUniversalB7  # TODO: not implemented yet
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Automatically detects the scene and selects the appropriate network for segmentation
+
+        Args:
+            interface: Interface instance
+            images: list of images
+
+        Returns:
+            list of masks
+        """
+        scene_analysis = self.scene_classifier(images)
+        images_per_scene = {}
+        for i, image in enumerate(images):
+            scene_name = scene_analysis[i][0][0]
+            if scene_name not in images_per_scene:
+                images_per_scene[scene_name] = []
+            images_per_scene[scene_name].append(image)
+
+        masks_per_scene = {}
+        for scene_name, igs in list(images_per_scene.items()):
+            net = self.select_net(scene_name)
+            if isinstance(interface.segmentation_pipeline, net):
+                masks_per_scene[scene_name] = interface.segmentation_pipeline(igs)
+            else:
+                old_device = interface.segmentation_pipeline.device
+                interface.segmentation_pipeline.to(
+                    "cpu"
+                )  # unload model from gpu, to avoid OOM
+                net_instance = net(device=old_device)
+                masks_per_scene[scene_name] = net_instance(igs)
+                del net_instance
+                interface.segmentation_pipeline.to(old_device)  # load model back to gpu
+
+        # restore one list of masks with the same order as images
+        masks = []
+        for i, image in enumerate(images):
+            scene_name = scene_analysis[i][0][0]
+            masks.append(
+                masks_per_scene[scene_name][images_per_scene[scene_name].index(image)]
+            )
+
+        return masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class AutoScene +(scene_classifier: SceneClassifier) +
+
+

AutoScene preprocessing method

+

Args

+
+
scene_classifier
+
SceneClassifier instance
+
+
+ +Expand source code + +
class AutoScene:
+    """AutoScene preprocessing method"""
+
+    def __init__(self, scene_classifier: SceneClassifier):
+        """
+        Args:
+            scene_classifier: SceneClassifier instance
+        """
+        self.scene_classifier = scene_classifier
+
+    @staticmethod
+    def select_net(scene: str):
+        """
+        Selects the network to be used for segmentation based on the detected scene
+
+        Args:
+            scene: scene name
+        """
+        if scene == "hard":
+            return TracerUniversalB7
+        elif scene == "soft":
+            return U2NET
+        elif scene == "digital":
+            return TracerUniversalB7  # TODO: not implemented yet
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Automatically detects the scene and selects the appropriate network for segmentation
+
+        Args:
+            interface: Interface instance
+            images: list of images
+
+        Returns:
+            list of masks
+        """
+        scene_analysis = self.scene_classifier(images)
+        images_per_scene = {}
+        for i, image in enumerate(images):
+            scene_name = scene_analysis[i][0][0]
+            if scene_name not in images_per_scene:
+                images_per_scene[scene_name] = []
+            images_per_scene[scene_name].append(image)
+
+        masks_per_scene = {}
+        for scene_name, igs in list(images_per_scene.items()):
+            net = self.select_net(scene_name)
+            if isinstance(interface.segmentation_pipeline, net):
+                masks_per_scene[scene_name] = interface.segmentation_pipeline(igs)
+            else:
+                old_device = interface.segmentation_pipeline.device
+                interface.segmentation_pipeline.to(
+                    "cpu"
+                )  # unload model from gpu, to avoid OOM
+                net_instance = net(device=old_device)
+                masks_per_scene[scene_name] = net_instance(igs)
+                del net_instance
+                interface.segmentation_pipeline.to(old_device)  # load model back to gpu
+
+        # restore one list of masks with the same order as images
+        masks = []
+        for i, image in enumerate(images):
+            scene_name = scene_analysis[i][0][0]
+            masks.append(
+                masks_per_scene[scene_name][images_per_scene[scene_name].index(image)]
+            )
+
+        return masks
+
+

Static methods

+
+
+def select_net(scene: str) +
+
+

Selects the network to be used for segmentation based on the detected scene

+

Args

+
+
scene
+
scene name
+
+
+ +Expand source code + +
@staticmethod
+def select_net(scene: str):
+    """
+    Selects the network to be used for segmentation based on the detected scene
+
+    Args:
+        scene: scene name
+    """
+    if scene == "hard":
+        return TracerUniversalB7
+    elif scene == "soft":
+        return U2NET
+    elif scene == "digital":
+        return TracerUniversalB7  # TODO: not implemented yet
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/preprocessing/index.html b/docs/api/carvekit/pipelines/preprocessing/index.html new file mode 100644 index 0000000..2c01a4a --- /dev/null +++ b/docs/api/carvekit/pipelines/preprocessing/index.html @@ -0,0 +1,81 @@ + + + + + + +carvekit.pipelines.preprocessing API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.preprocessing

+
+
+
+ +Expand source code + +
from carvekit.pipelines.preprocessing.stub import PreprocessingStub
+from carvekit.pipelines.preprocessing.autoscene import AutoScene
+
+
+
+

Sub-modules

+
+
carvekit.pipelines.preprocessing.autoscene
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.pipelines.preprocessing.stub
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/pipelines/preprocessing/stub.html b/docs/api/carvekit/pipelines/preprocessing/stub.html new file mode 100644 index 0000000..db50267 --- /dev/null +++ b/docs/api/carvekit/pipelines/preprocessing/stub.html @@ -0,0 +1,127 @@ + + + + + + +carvekit.pipelines.preprocessing.stub API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.pipelines.preprocessing.stub

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+from typing import Union, List
+
+from PIL import Image
+
+__all__ = ["PreprocessingStub"]
+
+
+class PreprocessingStub:
+    """Stub for future preprocessing methods"""
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Passes data through `interface.segmentation_pipeline()` method
+
+        Args:
+        - `interface`: Interface instance
+        - `images`: list of images
+
+        Returns:
+            the result of passing data through segmentation_pipeline method of interface
+        """
+        return interface.segmentation_pipeline(images=images)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class PreprocessingStub +
+
+

Stub for future preprocessing methods

+
+ +Expand source code + +
class PreprocessingStub:
+    """Stub for future preprocessing methods"""
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Passes data through `interface.segmentation_pipeline()` method
+
+        Args:
+        - `interface`: Interface instance
+        - `images`: list of images
+
+        Returns:
+            the result of passing data through segmentation_pipeline method of interface
+        """
+        return interface.segmentation_pipeline(images=images)
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/trimap/add_ops.html b/docs/api/carvekit/trimap/add_ops.html new file mode 100644 index 0000000..24e3625 --- /dev/null +++ b/docs/api/carvekit/trimap/add_ops.html @@ -0,0 +1,321 @@ + + + + + + +carvekit.trimap.add_ops API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.trimap.add_ops

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import cv2
+import numpy as np
+from PIL import Image
+
+
+def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image:
+    """
+    Applies a filter to the mask by the probability of locating an object in the object area.
+
+    Args:
+        prob_threshold (int, default=231): Probability threshold; mask values at or below it are marked as background.
+        mask (Image.Image): Predicted object mask
+
+    Raises:
+        ValueError: if mask has wrong color mode
+
+    Returns:
+        Image.Image: Thresholded mask containing only 0 and 255 values.
+    """
+    if mask.mode != "L":
+        raise ValueError("Input mask has wrong color mode.")
+    # noinspection PyTypeChecker
+    mask_array = np.array(mask)
+    mask_array[mask_array > prob_threshold] = 255  # Probability filter for mask
+    mask_array[mask_array <= prob_threshold] = 0
+    return Image.fromarray(mask_array).convert("L")
+
+
+def prob_as_unknown_area(
+    trimap: Image.Image, mask: Image.Image, prob_threshold=255
+) -> Image.Image:
+    """
+    Marks any uncertainty in the seg mask as an unknown region.
+
+    Args:
+        prob_threshold (int, default=255): Probability threshold; nonzero mask values at or below it are marked as unknown.
+        trimap (Image.Image): Generated trimap.
+        mask (Image.Image): Predicted object mask
+
+    Raises:
+        ValueError: if mask or trimap has wrong color mode
+
+    Returns:
+        Image.Image: Generated trimap for image.
+    """
+    if mask.mode != "L" or trimap.mode != "L":
+        raise ValueError("Input mask has wrong color mode.")
+    # noinspection PyTypeChecker
+    mask_array = np.array(mask)
+    # noinspection PyTypeChecker
+    trimap_array = np.array(trimap)
+    trimap_array[np.logical_and(mask_array <= prob_threshold, mask_array > 0)] = 127
+    return Image.fromarray(trimap_array).convert("L")
+
+
+def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image:
+    """
+    Performs erosion on the mask and marks the resulting area as an unknown region.
+
+    Args:
+        erosion_iters (int, default=1): The number of iterations of erosion that
+        the object's mask will be subjected to before forming an unknown area
+        trimap (Image.Image): Generated trimap.
+
+    Returns:
+        Image.Image: Generated trimap for image.
+    """
+    if trimap.mode != "L":
+        raise ValueError("Input mask has wrong color mode.")
+    # noinspection PyTypeChecker
+    trimap_array = np.array(trimap)
+    if erosion_iters > 0:
+        without_unknown_area = trimap_array.copy()
+        without_unknown_area[without_unknown_area == 127] = 0
+
+        erosion_kernel = np.ones((3, 3), np.uint8)
+        erode = cv2.erode(
+            without_unknown_area, erosion_kernel, iterations=erosion_iters
+        )
+        erode = np.where(erode == 0, 0, without_unknown_area)
+        trimap_array[np.logical_and(erode == 0, without_unknown_area > 0)] = 127
+        erode = trimap_array.copy()
+    else:
+        erode = trimap_array.copy()
+    return Image.fromarray(erode).convert("L")
+
+
+
+
+
+
+
+

Functions

+
+
+def post_erosion(trimap: PIL.Image.Image, erosion_iters=1) ‑> PIL.Image.Image +
+
+

Performs erosion on the mask and marks the resulting area as an unknown region.

+

Args

+
+
erosion_iters : int, default=1
+
The number of iterations of erosion that
+
the object's mask will be subjected to before forming an unknown area
+
trimap : Image.Image
+
Generated trimap.
+
+

Returns

+
+
Image.Image
+
Generated trimap for image.
+
+
+ +Expand source code + +
def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image:
+    """
+    Performs erosion on the mask and marks the resulting area as an unknown region.
+
+    Args:
+        erosion_iters (int, default=1): The number of iterations of erosion that
+        the object's mask will be subjected to before forming an unknown area
+        trimap (Image.Image): Generated trimap.
+
+    Returns:
+        Image.Image: Generated trimap for image.
+    """
+    if trimap.mode != "L":
+        raise ValueError("Input mask has wrong color mode.")
+    # noinspection PyTypeChecker
+    trimap_array = np.array(trimap)
+    if erosion_iters > 0:
+        without_unknown_area = trimap_array.copy()
+        without_unknown_area[without_unknown_area == 127] = 0
+
+        erosion_kernel = np.ones((3, 3), np.uint8)
+        erode = cv2.erode(
+            without_unknown_area, erosion_kernel, iterations=erosion_iters
+        )
+        erode = np.where(erode == 0, 0, without_unknown_area)
+        trimap_array[np.logical_and(erode == 0, without_unknown_area > 0)] = 127
+        erode = trimap_array.copy()
+    else:
+        erode = trimap_array.copy()
+    return Image.fromarray(erode).convert("L")
+
+
+
+def prob_as_unknown_area(trimap:Β PIL.Image.Image, mask:Β PIL.Image.Image, prob_threshold=255) ‑>Β PIL.Image.Image +
+
+

Marks any uncertainty in the seg mask as an unknown region.

+

Args

+
+
prob_threshold : int, default=255
+
Threshold of probability for mark area as unknown.
+
trimap : Image.Image
+
Generated trimap.
+
mask : Image.Image
+
Predicted object mask
+
+

Raises

+
+
ValueError
+
if mask or trimap has wrong color mode
+
+

Returns

+
+
Image.Image
+
Generated trimap for image.
+
+
+ +Expand source code + +
def prob_as_unknown_area(
+    trimap: Image.Image, mask: Image.Image, prob_threshold=255
+) -> Image.Image:
+    """
+    Marks any uncertainty in the seg mask as an unknown region.
+
+    Args:
+        prob_threshold (int, default=255): Threshold of probability for mark area as unknown.
+        trimap (Image.Image): Generated trimap.
+        mask (Image.Image): Predicted object mask
+
+    Raises:
+        ValueError: if mask or trimap has wrong color mode
+
+    Returns:
+        Image.Image: Generated trimap for image.
+    """
+    if mask.mode != "L" or trimap.mode != "L":
+        raise ValueError("Input mask has wrong color mode.")
+    # noinspection PyTypeChecker
+    mask_array = np.array(mask)
+    # noinspection PyTypeChecker
+    trimap_array = np.array(trimap)
+    trimap_array[np.logical_and(mask_array <= prob_threshold, mask_array > 0)] = 127
+    return Image.fromarray(trimap_array).convert("L")
+
+
+
+def prob_filter(mask:Β PIL.Image.Image, prob_threshold=231) ‑>Β PIL.Image.Image +
+
+

Applies a filter to the mask by the probability of locating an object in the object area.

+

Args

+
+
prob_threshold : int, default=231
+
Threshold of probability for mark area as background.
+
mask : Image.Image
+
Predicted object mask
+
+

Raises

+
+
ValueError
+
if mask has wrong color mode
+
+

Returns

+
+
Image.Image
+
Filtered binary mask.
+
+
+ +Expand source code + +
def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image:
+    """
+    Applies a filter to the mask by the probability of locating an object in the object area.
+
+    Args:
+        prob_threshold (int, default=231): Threshold of probability for mark area as background.
+        mask (Image.Image): Predicted object mask
+
+    Raises:
+        ValueError: if mask or trimap has wrong color mode
+
+    Returns:
+        Image.Image: generated trimap for image.
+    """
+    if mask.mode != "L":
+        raise ValueError("Input mask has wrong color mode.")
+    # noinspection PyTypeChecker
+    mask_array = np.array(mask)
+    mask_array[mask_array > prob_threshold] = 255  # Probability filter for mask
+    mask_array[mask_array <= prob_threshold] = 0
+    return Image.fromarray(mask_array).convert("L")
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/trimap/cv_gen.html b/docs/api/carvekit/trimap/cv_gen.html new file mode 100644 index 0000000..dbcaf39 --- /dev/null +++ b/docs/api/carvekit/trimap/cv_gen.html @@ -0,0 +1,215 @@ + + + + + + +carvekit.trimap.cv_gen API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.trimap.cv_gen

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import PIL.Image
+import cv2
+import numpy as np
+
+
+class CV2TrimapGenerator:
+    def __init__(self, kernel_size: int = 30, erosion_iters: int = 1):
+        """
+        Initialize a new CV2TrimapGenerator instance
+
+        Args:
+            kernel_size (int, default=30): The size of the offset from the object mask
+            in pixels when an unknown area is detected in the trimap
+            erosion_iters (int, default=1): The number of iterations of erosion that
+            the object's mask will be subjected to before forming an unknown area
+        """
+        self.kernel_size = kernel_size
+        self.erosion_iters = erosion_iters
+
+    def __call__(
+        self, original_image: PIL.Image.Image, mask: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Generates trimap based on predicted object mask to refine object mask borders.
+        Based on cv2 erosion algorithm.
+
+        Args:
+            original_image (PIL.Image.Image): Original image; only its size is
+            compared with the size of the mask
+            mask (PIL.Image.Image): Predicted object mask in "L" mode
+
+        Raises:
+            ValueError: if the mask is not in "L" mode or its size differs
+            from the size of the original image
+
+        Returns:
+            PIL.Image.Image: Generated trimap for image.
+        """
+        if mask.mode != "L":
+            raise ValueError("Input mask has wrong color mode.")
+        if mask.size != original_image.size:
+            raise ValueError("Sizes of input image and predicted mask doesn't equal")
+        # noinspection PyTypeChecker
+        mask_array = np.array(mask)
+        # Square dilation kernel that reaches kernel_size pixels in every direction
+        pixels = 2 * self.kernel_size + 1
+        kernel = np.ones((pixels, pixels), np.uint8)
+
+        if self.erosion_iters > 0:
+            erosion_kernel = np.ones((3, 3), np.uint8)
+            erode = cv2.erode(mask_array, erosion_kernel, iterations=self.erosion_iters)
+            # Keep the original mask values wherever the erosion left a non-zero pixel
+            erode = np.where(erode == 0, 0, mask_array)
+        else:
+            erode = mask_array.copy()
+
+        dilation = cv2.dilate(erode, kernel, iterations=1)
+
+        dilation = np.where(dilation == 255, 127, dilation)  # dilated white -> gray (127)
+        trimap = np.where(erode > 127, 200, dilation)  # tag the object interior with a temporary 200
+
+        trimap = np.where(trimap < 127, 0, trimap)  # everything darker than gray -> 0
+        trimap = np.where(trimap > 200, 0, trimap)  # everything above the temporary tag -> 0
+        trimap = np.where(trimap == 200, 255, trimap)  # temporary 200 -> white (255)
+
+        return PIL.Image.fromarray(trimap).convert("L")
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CV2TrimapGenerator +(kernel_size:Β intΒ =Β 30, erosion_iters:Β intΒ =Β 1) +
+
+

Initialize a new CV2TrimapGenerator instance

+

Args

+
+
kernel_size : int, default=30
+
The size of the offset from the object mask
+
+

in pixels when an unknown area is detected in the trimap +erosion_iters (int, default=1): The number of iterations of erosion that +the object's mask will be subjected to before forming an unknown area

+
+ +Expand source code + +
class CV2TrimapGenerator:
+    def __init__(self, kernel_size: int = 30, erosion_iters: int = 1):
+        """
+        Initialize a new CV2TrimapGenerator instance
+
+        Args:
+            kernel_size (int, default=30): The size of the offset from the object mask
+            in pixels when an unknown area is detected in the trimap
+            erosion_iters (int, default=1): The number of iterations of erosion that
+            the object's mask will be subjected to before forming an unknown area
+        """
+        self.kernel_size = kernel_size
+        self.erosion_iters = erosion_iters
+
+    def __call__(
+        self, original_image: PIL.Image.Image, mask: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Generates trimap based on predicted object mask to refine object mask borders.
+        Based on cv2 erosion algorithm.
+
+        Args:
+            original_image (PIL.Image.Image): Original image; only its size is
+            compared with the size of the mask
+            mask (PIL.Image.Image): Predicted object mask in "L" mode
+
+        Raises:
+            ValueError: if the mask is not in "L" mode or its size differs
+            from the size of the original image
+
+        Returns:
+            PIL.Image.Image: Generated trimap for image.
+        """
+        if mask.mode != "L":
+            raise ValueError("Input mask has wrong color mode.")
+        if mask.size != original_image.size:
+            raise ValueError("Sizes of input image and predicted mask doesn't equal")
+        # noinspection PyTypeChecker
+        mask_array = np.array(mask)
+        # Square dilation kernel that reaches kernel_size pixels in every direction
+        pixels = 2 * self.kernel_size + 1
+        kernel = np.ones((pixels, pixels), np.uint8)
+
+        if self.erosion_iters > 0:
+            erosion_kernel = np.ones((3, 3), np.uint8)
+            erode = cv2.erode(mask_array, erosion_kernel, iterations=self.erosion_iters)
+            # Keep the original mask values wherever the erosion left a non-zero pixel
+            erode = np.where(erode == 0, 0, mask_array)
+        else:
+            erode = mask_array.copy()
+
+        dilation = cv2.dilate(erode, kernel, iterations=1)
+
+        dilation = np.where(dilation == 255, 127, dilation)  # dilated white -> gray (127)
+        trimap = np.where(erode > 127, 200, dilation)  # tag the object interior with a temporary 200
+
+        trimap = np.where(trimap < 127, 0, trimap)  # everything darker than gray -> 0
+        trimap = np.where(trimap > 200, 0, trimap)  # everything above the temporary tag -> 0
+        trimap = np.where(trimap == 200, 255, trimap)  # temporary 200 -> white (255)
+
+        return PIL.Image.fromarray(trimap).convert("L")
+
+

Subclasses

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/trimap/generator.html b/docs/api/carvekit/trimap/generator.html new file mode 100644 index 0000000..f241a2f --- /dev/null +++ b/docs/api/carvekit/trimap/generator.html @@ -0,0 +1,187 @@ + + + + + + +carvekit.trimap.generator API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.trimap.generator

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from PIL import Image
+from carvekit.trimap.cv_gen import CV2TrimapGenerator
+from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion
+
+
+class TrimapGenerator(CV2TrimapGenerator):
+    def __init__(
+        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
+    ):
+        """
+        Initialize a TrimapGenerator instance
+
+        Args:
+            prob_threshold (int, default=231): Probability threshold at which the
+            prob_filter and prob_as_unknown_area operations will be applied
+            kernel_size (int, default=30): The size of the offset from the object mask
+            in pixels when an unknown area is detected in the trimap
+            erosion_iters (int, default=5): The number of iterations of erosion that
+            the object's mask will be subjected to before forming an unknown area
+        """
+        super().__init__(kernel_size, erosion_iters=0)
+        self.prob_threshold = prob_threshold
+        self.__erosion_iters = erosion_iters
+
+    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
+        """
+        Generates trimap based on predicted object mask to refine object mask borders.
+        Based on cv2 erosion algorithm and additional prob. filters.
+
+        Args:
+            original_image (Image.Image): Original image
+            mask (Image.Image): Predicted object mask
+
+        Returns:
+            Image.Image: Generated trimap for image.
+        """
+        filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
+        trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask)
+        new_trimap = prob_as_unknown_area(
+            trimap=trimap, mask=mask, prob_threshold=self.prob_threshold
+        )
+        new_trimap = post_erosion(new_trimap, self.__erosion_iters)
+        return new_trimap
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class TrimapGenerator +(prob_threshold:Β intΒ =Β 231, kernel_size:Β intΒ =Β 30, erosion_iters:Β intΒ =Β 5) +
+
+

Initialize a TrimapGenerator instance

+

Args

+
+
prob_threshold : int, default=231
+
Probability threshold at which the
+
prob_filter and prob_as_unknown_area operations will be applied
+
kernel_size : int, default=30
+
The size of the offset from the object mask
+
in pixels when an unknown area is detected in the trimap
+
erosion_iters : int, default=5
+
The number of iterations of erosion that
+
+

the object's mask will be subjected to before forming an unknown area

+
+ +Expand source code + +
class TrimapGenerator(CV2TrimapGenerator):
+    def __init__(
+        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
+    ):
+        """
+        Initialize a TrimapGenerator instance
+
+        Args:
+            prob_threshold (int, default=231): Probability threshold at which the
+            prob_filter and prob_as_unknown_area operations will be applied
+            kernel_size (int, default=30): The size of the offset from the object mask
+            in pixels when an unknown area is detected in the trimap
+            erosion_iters (int, default=5): The number of iterations of erosion that
+            the object's mask will be subjected to before forming an unknown area
+        """
+        super().__init__(kernel_size, erosion_iters=0)
+        self.prob_threshold = prob_threshold
+        self.__erosion_iters = erosion_iters
+
+    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
+        """
+        Generates trimap based on predicted object mask to refine object mask borders.
+        Based on cv2 erosion algorithm and additional prob. filters.
+
+        Args:
+            original_image (Image.Image): Original image
+            mask (Image.Image): Predicted object mask
+
+        Returns:
+            Image.Image: Generated trimap for image.
+        """
+        filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
+        trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask)
+        new_trimap = prob_as_unknown_area(
+            trimap=trimap, mask=mask, prob_threshold=self.prob_threshold
+        )
+        new_trimap = post_erosion(new_trimap, self.__erosion_iters)
+        return new_trimap
+
+

Ancestors

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/trimap/index.html b/docs/api/carvekit/trimap/index.html new file mode 100644 index 0000000..315ef68 --- /dev/null +++ b/docs/api/carvekit/trimap/index.html @@ -0,0 +1,81 @@ + + + + + + +carvekit.trimap API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.trimap

+
+
+
+
+

Sub-modules

+
+
carvekit.trimap.add_ops
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.trimap.cv_gen
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
carvekit.trimap.generator
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/download_models.html b/docs/api/carvekit/utils/download_models.html new file mode 100644 index 0000000..67b6c2d --- /dev/null +++ b/docs/api/carvekit/utils/download_models.html @@ -0,0 +1,780 @@ + + + + + + +carvekit.utils.download_models API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils.download_models

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import hashlib
+import os
+import warnings
+from abc import ABCMeta, abstractmethod, ABC
+from pathlib import Path
+from typing import Optional
+
+import carvekit
+from carvekit.ml.files import checkpoints_dir
+
+import requests
+import tqdm
+
+# NOTE: rebinds the name `requests` from the module to one shared Session, so
+# later `requests.get(...)` calls in this module go through this session.
+requests = requests.Session()
+# Send the carvekit version in the User-Agent header of every request.
+requests.headers.update({"User-Agent": f"Carvekit/{carvekit.version}"})
+
+MODELS_URLS = {
+    "basnet.pth": {
+        "repository": "Carve/basnet-universal",
+        "revision": "870becbdb364fda6d8fdb2c10b072542f8d08701",
+        "filename": "basnet.pth",
+    },
+    "deeplab.pth": {
+        "repository": "Carve/deeplabv3-resnet101",
+        "revision": "d504005392fc877565afdf58aad0cd524682d2b0",
+        "filename": "deeplab.pth",
+    },
+    "fba_matting.pth": {
+        "repository": "Carve/fba",
+        "revision": "a5d3457df0fb9c88ea19ed700d409756ca2069d1",
+        "filename": "fba_matting.pth",
+    },
+    "u2net.pth": {
+        "repository": "Carve/u2net-universal",
+        "revision": "10305d785481cf4b2eee1d447c39cd6e5f43d74b",
+        "filename": "full_weights.pth",
+    },
+    "tracer_b7.pth": {
+        "repository": "Carve/tracer_b7",
+        "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5",
+        "filename": "tracer_b7.pth",
+    },
+    "scene_classifier.pth": {
+        "repository": "Carve/scene_classifier",
+        "revision": "71c8e4c771dd5a20ff0c5c9e3c8f1c9cf8082740",
+        "filename": "scene_classifier.pth",
+    },
+    "yolov4_coco_with_classes.pth": {
+        "repository": "Carve/yolov4_coco",
+        "revision": "e3fc9cd22f86e456d2749d1ae148400f2f950fb3",
+        "filename": "yolov4_coco_with_classes.pth",
+    },
+    "cascadepsp.pth": {
+        "repository": "Carve/cascadepsp",
+        "revision": "3ca1e5e432344b1277bc88d1c6d4265c46cff62f",
+        "filename": "cascadepsp.pth",
+    },
+}
+"""
+All data needed to build path relative to huggingface.co for model download
+"""
+
+MODELS_CHECKSUMS = {
+    "basnet.pth": "e409cb709f4abca87cb11bd44a9ad3f909044a917977ab65244b4c94dd33"
+    "8b1a37755c4253d7cb54526b7763622a094d7b676d34b5e6886689256754e5a5e6ad",
+    "deeplab.pth": "9c5a1795bc8baa267200a44b49ac544a1ba2687d210f63777e4bd715387324469a59b072f8a28"
+    "9cc471c637b367932177e5b312e8ea6351c1763d9ff44b4857c",
+    "fba_matting.pth": "890906ec94c1bfd2ad08707a63e4ccb0955d7f5d25e32853950c24c78"
+    "4cbad2e59be277999defc3754905d0f15aa75702cdead3cfe669ff72f08811c52971613",
+    "u2net.pth": "16f8125e2fedd8c85db0e001ee15338b4aa2fda77bab8ba70c25e"
+    "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7",
+    "tracer_b7.pth": "c439c5c12d4d43d5f9be9ec61e68b2e54658a541bccac2577ef5a54fb252b6e8415d41f7e"
+    "c2487033d0c02b4dd08367958e4e62091318111c519f93e2632be7b",
+    "scene_classifier.pth": "6d8692510abde453b406a1fea557afdea62fd2a2a2677283a3ecc2"
+    "341a4895ee99ed65cedcb79b80775db14c3ffcfc0aad2caec1d85140678852039d2d4e76b4",
+    "yolov4_coco_with_classes.pth": "44b6ec2dd35dc3802bf8c512002f76e00e97bfbc86bc7af6de2fafce229a41b4ca"
+    "12c6f3d7589278c71cd4ddd62df80389b148c19b84fa03216905407a107fff",
+    "cascadepsp.pth": "3f895f5126d80d6f73186f045557ea7c8eab4dfa3d69a995815bb2c03d564573f36c474f04d7bf0022a27829f583a1a793b036adf801cb423e41a4831b830122",
+}
+"""
+Model -> checksum dictionary
+"""
+
+
+def sha512_checksum_calc(file: Path) -> str:
+    """
+    Calculates the SHA512 hash digest of a file on fs
+
+    Args:
+        file (Path): Path to the file
+
+    Returns:
+        SHA512 hash digest of a file.
+    """
+    dd = hashlib.sha512()
+    with file.open("rb") as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            dd.update(chunk)
+    return dd.hexdigest()
+
+
+class CachedDownloader:
+    """
+    Metaclass for models downloaders.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        return self.__class__.__name__
+
+    @property
+    @abstractmethod
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        """
+        Property MAY be overriden in subclasses.
+        Used in case if subclass failed to download model. So preferred downloader SHOULD be placed higher in the hierarchy.
+        Less preferred downloader SHOULD be provided by this property.
+        """
+        pass
+
+    def download_model(self, file_name: str) -> Path:
+        """
+        Downloads model from the internet and saves it to the cache.
+
+        Behavior:
+            If model is already downloaded it will be loaded from the cache.
+
+            If model is already downloaded, but checksum is invalid, it will be downloaded again.
+
+            If model download failed, fallback downloader will be used.
+        """
+        try:
+            return self.download_model_base(file_name)
+        except BaseException as e:
+            if self.fallback_downloader is not None:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" Trying to download from {self.fallback_downloader.name} downloader."
+                )
+                return self.fallback_downloader.download_model(file_name)
+            else:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" No fallback downloader available."
+                )
+                raise e
+
+    @abstractmethod
+    def download_model_base(self, model_name: str) -> Path:
+        """
+        Download model from any source if not cached.
+        Returns:
+            pathlib.Path: Path to the downloaded model.
+        """
+
+    def __call__(self, model_name: str):
+        return self.download_model(model_name)
+
+
+class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
+    """
+    Downloader for models from HuggingFace Hub.
+    Private models are not supported.
+    """
+
+    def __init__(
+        self,
+        name: str = "Huggingface.co",
+        base_url: str = "https://huggingface.co",
+        fb_downloader: Optional["CachedDownloader"] = None,
+    ):
+        self.cache_dir = checkpoints_dir
+        """SHOULD be same for all instances to prevent downloading same model multiple times
+        Points to ~/.cache/carvekit/checkpoints"""
+        self.base_url = base_url
+        """MUST be a base url with protocol and domain name to huggingface or another, compatible in terms of models downloading API source"""
+        self._name = name
+        self._fallback_downloader = fb_downloader
+
+    @property
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        return self._fallback_downloader
+
+    @property
+    def name(self):
+        return self._name
+
+    def check_for_existence(self, model_name: str) -> Optional[Path]:
+        """
+        Checks if model is already downloaded and cached. Verifies file integrity by checksum.
+        Returns:
+            Optional[pathlib.Path]: Path to the cached model if cached.
+        """
+        if model_name not in MODELS_URLS.keys():
+            raise FileNotFoundError("Unknown model!")
+        path = (
+            self.cache_dir
+            / MODELS_URLS[model_name]["repository"].split("/")[1]
+            / model_name
+        )
+
+        if not path.exists():
+            return None
+
+        if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
+            warnings.warn(
+                f"Invalid checksum for model {path.name}. Downloading correct model!"
+            )
+            os.remove(path)
+            return None
+        return path
+
+    def download_model_base(self, model_name: str) -> Path:
+        cached_path = self.check_for_existence(model_name)
+        if cached_path is not None:
+            return cached_path
+        else:
+            cached_path = (
+                self.cache_dir
+                / MODELS_URLS[model_name]["repository"].split("/")[1]
+                / model_name
+            )
+            cached_path.parent.mkdir(parents=True, exist_ok=True)
+            url = MODELS_URLS[model_name]
+            hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"
+
+            try:
+                r = requests.get(hugging_face_url, stream=True, timeout=10)
+                if r.status_code < 400:
+                    with open(cached_path, "wb") as f:
+                        r.raw.decode_content = True
+                        for chunk in tqdm.tqdm(
+                            r,
+                            desc="Downloading " + cached_path.name + " model",
+                            colour="blue",
+                        ):
+                            f.write(chunk)
+                else:
+                    if r.status_code == 404:
+                        raise FileNotFoundError(f"Model {model_name} not found!")
+                    else:
+                        raise ConnectionError(
+                            f"Error {r.status_code} while downloading model {model_name}!"
+                        )
+            except BaseException as e:
+                if cached_path.exists():
+                    os.remove(cached_path)
+                raise ConnectionError(
+                    f"Exception caught when downloading model! "
+                    f"Model name: {cached_path.name}. Exception: {str(e)}."
+                )
+            return cached_path
+
+
+fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader()
+downloader: CachedDownloader = HuggingFaceCompatibleDownloader(
+    base_url="https://cdn.carve.photos",
+    fb_downloader=fallback_downloader,
+    name="Carve CDN",
+)
+downloader._fallback_downloader = fallback_downloader
+
+
+
+
+
+

Global variables

+
+
var MODELS_CHECKSUMS
+
+

Model -> checksum dictionary

+
+
var MODELS_URLS
+
+

All data needed to build path relative to huggingface.co for model download

+
+
+
+
+

Functions

+
+
+def sha512_checksum_calc(file:Β pathlib.Path) ‑>Β str +
+
+

Calculates the SHA512 hash digest of a file on fs

+

Args

+
+
file : Path
+
Path to the file
+
+

Returns

+

SHA512 hash digest of a file.

+
+ +Expand source code + +
def sha512_checksum_calc(file: Path) -> str:
+    """
+    Calculates the SHA512 hash digest of a file on fs
+
+    Args:
+        file (Path): Path to the file
+
+    Returns:
+        SHA512 hash digest of a file.
+    """
+    dd = hashlib.sha512()
+    with file.open("rb") as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            dd.update(chunk)
+    return dd.hexdigest()
+
+
+
+
+
+

Classes

+
+
+class CachedDownloader +
+
+

Abstract base class for model downloaders.

+
+ +Expand source code + +
class CachedDownloader:
+    """
+    Metaclass for models downloaders.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        return self.__class__.__name__
+
+    @property
+    @abstractmethod
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        """
+        Property MAY be overriden in subclasses.
+        Used in case if subclass failed to download model. So preferred downloader SHOULD be placed higher in the hierarchy.
+        Less preferred downloader SHOULD be provided by this property.
+        """
+        pass
+
+    def download_model(self, file_name: str) -> Path:
+        """
+        Downloads model from the internet and saves it to the cache.
+
+        Behavior:
+            If model is already downloaded it will be loaded from the cache.
+
+            If model is already downloaded, but checksum is invalid, it will be downloaded again.
+
+            If model download failed, fallback downloader will be used.
+        """
+        try:
+            return self.download_model_base(file_name)
+        except BaseException as e:
+            if self.fallback_downloader is not None:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" Trying to download from {self.fallback_downloader.name} downloader."
+                )
+                return self.fallback_downloader.download_model(file_name)
+            else:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" No fallback downloader available."
+                )
+                raise e
+
+    @abstractmethod
+    def download_model_base(self, model_name: str) -> Path:
+        """
+        Download model from any source if not cached.
+        Returns:
+            pathlib.Path: Path to the downloaded model.
+        """
+
+    def __call__(self, model_name: str):
+        return self.download_model(model_name)
+
+

Subclasses

+ +

Instance variables

+
+
var fallback_downloader :Β Optional[CachedDownloader]
+
+

Property MAY be overridden in subclasses. +Used when a subclass fails to download a model, so the preferred downloader SHOULD be placed higher in the hierarchy. +Less preferred downloader SHOULD be provided by this property.

+
+ +Expand source code + +
@property
+@abstractmethod
+def fallback_downloader(self) -> Optional["CachedDownloader"]:
+    """
+    Property MAY be overridden in subclasses.
+    Used when a subclass fails to download a model, so the preferred downloader
+    SHOULD be placed higher in the hierarchy.
+    Less preferred downloader SHOULD be provided by this property.
+    """
+    pass
+
+
+
var name :Β str
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def name(self) -> str:
+    """Downloader name; this default implementation returns the class name."""
+    return self.__class__.__name__
+
+
+
+

Methods

+
+
+def download_model(self, file_name:Β str) ‑>Β pathlib.Path +
+
+

Downloads model from the internet and saves it to the cache.

+

Behavior

+

If model is already downloaded it will be loaded from the cache.

+

If model is already downloaded, but checksum is invalid, it will be downloaded again.

+

If model download failed, fallback downloader will be used.

+
+ +Expand source code + +
def download_model(self, file_name: str) -> Path:
+    """
+    Downloads model from the internet and saves it to the cache.
+
+    Behavior:
+        If model is already downloaded it will be loaded from the cache.
+
+        If model is already downloaded, but checksum is invalid, it will be downloaded again.
+
+        If model download failed, fallback downloader will be used.
+    """
+    try:
+        return self.download_model_base(file_name)
+    except BaseException as e:
+        if self.fallback_downloader is not None:
+            warnings.warn(
+                f"Failed to download model from {self.name} downloader."
+                f" Trying to download from {self.fallback_downloader.name} downloader."
+            )
+            return self.fallback_downloader.download_model(file_name)
+        else:
+            warnings.warn(
+                f"Failed to download model from {self.name} downloader."
+                f" No fallback downloader available."
+            )
+            raise e
+
+
+
+def download_model_base(self, model_name:Β str) ‑>Β pathlib.Path +
+
+

Download model from any source if not cached.

+

Returns

+
+
pathlib.Path
+
Path to the downloaded model.
+
+
+ +Expand source code + +
@abstractmethod
+def download_model_base(self, model_name: str) -> Path:
+    """
+    Download model from any source if not cached.
+    Returns:
+        pathlib.Path: Path to the downloaded model.
+    """
+
+
+
+
+
+class HuggingFaceCompatibleDownloader +(name: str = 'Huggingface.co', base_url: str = 'https://huggingface.co', fb_downloader: Optional[ForwardRef('CachedDownloader')] = None) +
+
+

Downloader for models from HuggingFace Hub. +Private models are not supported.

+
+ +Expand source code + +
class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
+    """
+    Downloader for models from HuggingFace Hub.
+    Private models are not supported.
+    """
+
+    def __init__(
+        self,
+        name: str = "Huggingface.co",
+        base_url: str = "https://huggingface.co",
+        fb_downloader: Optional["CachedDownloader"] = None,
+    ):
+        self.cache_dir = checkpoints_dir
+        """SHOULD be same for all instances to prevent downloading same model multiple times
+        Points to ~/.cache/carvekit/checkpoints"""
+        self.base_url = base_url
+        """MUST be a base url with protocol and domain name to huggingface or another, compatible in terms of models downloading API source"""
+        self._name = name
+        self._fallback_downloader = fb_downloader
+
+    @property
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        return self._fallback_downloader
+
+    @property
+    def name(self):
+        return self._name
+
+    def check_for_existence(self, model_name: str) -> Optional[Path]:
+        """
+        Checks if model is already downloaded and cached. Verifies file integrity by checksum.
+        Returns:
+            Optional[pathlib.Path]: Path to the cached model if cached.
+        """
+        if model_name not in MODELS_URLS.keys():
+            raise FileNotFoundError("Unknown model!")
+        path = (
+            self.cache_dir
+            / MODELS_URLS[model_name]["repository"].split("/")[1]
+            / model_name
+        )
+
+        if not path.exists():
+            return None
+
+        if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
+            warnings.warn(
+                f"Invalid checksum for model {path.name}. Downloading correct model!"
+            )
+            os.remove(path)
+            return None
+        return path
+
+    def download_model_base(self, model_name: str) -> Path:
+        cached_path = self.check_for_existence(model_name)
+        if cached_path is not None:
+            return cached_path
+        else:
+            cached_path = (
+                self.cache_dir
+                / MODELS_URLS[model_name]["repository"].split("/")[1]
+                / model_name
+            )
+            cached_path.parent.mkdir(parents=True, exist_ok=True)
+            url = MODELS_URLS[model_name]
+            hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"
+
+            try:
+                r = requests.get(hugging_face_url, stream=True, timeout=10)
+                if r.status_code < 400:
+                    with open(cached_path, "wb") as f:
+                        r.raw.decode_content = True
+                        for chunk in tqdm.tqdm(
+                            r,
+                            desc="Downloading " + cached_path.name + " model",
+                            colour="blue",
+                        ):
+                            f.write(chunk)
+                else:
+                    if r.status_code == 404:
+                        raise FileNotFoundError(f"Model {model_name} not found!")
+                    else:
+                        raise ConnectionError(
+                            f"Error {r.status_code} while downloading model {model_name}!"
+                        )
+            except BaseException as e:
+                if cached_path.exists():
+                    os.remove(cached_path)
+                raise ConnectionError(
+                    f"Exception caught when downloading model! "
+                    f"Model name: {cached_path.name}. Exception: {str(e)}."
+                )
+            return cached_path
+
+

Ancestors

+ +

Instance variables

+
+
var base_url
+
+

MUST be a base URL (protocol and domain name) of huggingface.co or of another source that is compatible with its model-downloading API.

+
+
var cache_dir
+
+

SHOULD be the same for all instances to prevent downloading the same model multiple times. +Points to ~/.cache/carvekit/checkpoints

+
+
var name
+
+
+
+ +Expand source code + +
@property
+def name(self):
+    return self._name
+
+
+
+

Methods

+
+
+def check_for_existence(self, model_name: str) -> Optional[pathlib.Path] +
+
+

Checks if model is already downloaded and cached. Verifies file integrity by checksum.

+

Returns

+
+
Optional[pathlib.Path]
+
Path to the cached model if cached.
+
+
+ +Expand source code + +
def check_for_existence(self, model_name: str) -> Optional[Path]:
+    """
+    Checks if model is already downloaded and cached. Verifies file integrity by checksum.
+    Returns:
+        Optional[pathlib.Path]: Path to the cached model if cached.
+    """
+    if model_name not in MODELS_URLS.keys():
+        raise FileNotFoundError("Unknown model!")
+    path = (
+        self.cache_dir
+        / MODELS_URLS[model_name]["repository"].split("/")[1]
+        / model_name
+    )
+
+    if not path.exists():
+        return None
+
+    if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
+        warnings.warn(
+            f"Invalid checksum for model {path.name}. Downloading correct model!"
+        )
+        os.remove(path)
+        return None
+    return path
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/fs_utils.html b/docs/api/carvekit/utils/fs_utils.html new file mode 100644 index 0000000..e31a554 --- /dev/null +++ b/docs/api/carvekit/utils/fs_utils.html @@ -0,0 +1,156 @@ + + + + + + +carvekit.utils.fs_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils.fs_utils

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO).

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+from pathlib import Path
+from PIL import Image
+import warnings
+from typing import Optional
+
+
+def save_file(output: Optional[Path], input_path: Path, image: Image.Image):
+    """
+    Saves an image to the file system
+
+    Args:
+        output (Optional[pathlib.Path]): Output path [dir or end file]
+        input_path (pathlib.Path): Input path of the image
+        image (Image.Image): Image to be saved.
+    """
+    if isinstance(output, Path) and str(output) != "none":
+        if output.is_dir() and output.exists():
+            image.save(output.joinpath(input_path.with_suffix(".png").name))
+        elif output.suffix != "":
+            if output.suffix != ".png":
+                warnings.warn(
+                    f"Only export with .png extension is supported! Your {output.suffix}"
+                    f" extension will be ignored and replaced with .png!"
+                )
+            image.save(output.with_suffix(".png"))
+        else:
+            raise ValueError("Wrong output path!")
+    elif output is None or str(output) == "none":
+        image.save(
+            input_path.with_name(
+                input_path.stem.split(".")[0] + "_bg_removed"
+            ).with_suffix(".png")
+        )
+
+
+
+
+
+
+
+

Functions

+
+
+def save_file(output: Optional[pathlib.Path], input_path: pathlib.Path, image: PIL.Image.Image) +
+
+

Saves an image to the file system

+

Args

+
+
output : Optional[pathlib.Path]
+
Output path [dir or end file]
+
input_path : pathlib.Path
+
Input path of the image
+
image : Image.Image
+
Image to be saved.
+
+
+ +Expand source code + +
def save_file(output: Optional[Path], input_path: Path, image: Image.Image):
+    """
+    Saves an image to the file system
+
+    Args:
+        output (Optional[pathlib.Path]): Output path [dir or end file]
+        input_path (pathlib.Path): Input path of the image
+        image (Image.Image): Image to be saved.
+    """
+    if isinstance(output, Path) and str(output) != "none":
+        if output.is_dir() and output.exists():
+            image.save(output.joinpath(input_path.with_suffix(".png").name))
+        elif output.suffix != "":
+            if output.suffix != ".png":
+                warnings.warn(
+                    f"Only export with .png extension is supported! Your {output.suffix}"
+                    f" extension will be ignored and replaced with .png!"
+                )
+            image.save(output.with_suffix(".png"))
+        else:
+            raise ValueError("Wrong output path!")
+    elif output is None or str(output) == "none":
+        image.save(
+            input_path.with_name(
+                input_path.stem.split(".")[0] + "_bg_removed"
+            ).with_suffix(".png")
+        )
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/image_utils.html b/docs/api/carvekit/utils/image_utils.html new file mode 100644 index 0000000..fde17ac --- /dev/null +++ b/docs/api/carvekit/utils/image_utils.html @@ -0,0 +1,515 @@ + + + + + + +carvekit.utils.image_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils.image_utils

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO).

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+    Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+    Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+    License: Apache License 2.0
+"""
+
+import pathlib
+from typing import Union, Any, Tuple
+
+import PIL.Image
+import numpy as np
+import torch
+
+ALLOWED_SUFFIXES = [".jpg", ".jpeg", ".bmp", ".png", ".webp"]
+
+
+def to_tensor(x: Any) -> torch.Tensor:
+    """
+    Returns a PIL.Image.Image as torch tensor without swap tensor dims.
+
+    Args:
+        x (PIL.Image.Image): image
+
+    Returns:
+        torch.Tensor: image as torch tensor
+    """
+    return torch.tensor(np.array(x, copy=True))
+
+
+def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image:
+    """Returns a `PIL.Image.Image` class by string path or `pathlib.Path` or `PIL.Image.Image` instance
+
+    Args:
+        file (Union[str, pathlib.Path, PIL.Image.Image]): File path or `PIL.Image.Image` instance
+
+    Returns:
+        PIL.Image.Image: image instance loaded from `file` location
+
+    Raises:
+        ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image
+
+    """
+    if isinstance(file, str) and is_image_valid(pathlib.Path(file)):
+        return PIL.Image.open(file)
+    elif isinstance(file, PIL.Image.Image):
+        return file
+    elif isinstance(file, pathlib.Path) and is_image_valid(file):
+        return PIL.Image.open(str(file))
+    else:
+        raise ValueError("Unknown input file type")
+
+
+def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image:
+    """Performs image conversion to correct color mode
+
+    Args:
+        image (PIL.Image.Image): `PIL.Image.Image` instance
+        mode (str, default=RGB): Color mode to convert
+
+    Returns:
+        PIL.Image.Image: converted image
+
+    Raises:
+        ValueError: If image hasn't convertable color mode, or it is too small
+    """
+    if is_image_valid(image):
+        return image.convert(mode)
+
+
+def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool:
+    """This function performs image validation.
+
+    Args:
+        image (Union[pathlib.Path, PIL.Image.Image]): Path to the image or `PIL.Image.Image` instance being checked.
+
+    Returns:
+        bool: True if image is valid, False otherwise.
+
+    Raises:
+        ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small
+
+    """
+    if isinstance(image, pathlib.Path):
+        if not image.exists():
+            raise ValueError("File is not exists")
+        elif image.is_dir():
+            raise ValueError("File is a directory")
+        elif image.suffix.lower() not in ALLOWED_SUFFIXES:
+            raise ValueError(
+                f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}"
+            )
+    elif isinstance(image, PIL.Image.Image):
+        if not (image.size[0] > 32 and image.size[1] > 32):
+            raise ValueError("Image should be bigger then (32x32) pixels.")
+        elif image.mode not in [
+            "RGB",
+            "RGBA",
+            "L",
+        ]:
+            raise ValueError("Wrong image color mode.")
+    else:
+        raise ValueError("Unknown input file type")
+    return True
+
+
+def transparency_paste(
+    bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)
+) -> PIL.Image.Image:
+    """
+    Inserts an image into another image while maintaining transparency.
+
+    Args:
+        bg_img (PIL.Image.Image): background image
+        fg_img (PIL.Image.Image): foreground image
+        box (tuple[int, int]): place to paste
+
+    Returns:
+        PIL.Image.Image: Background image with pasted foreground image at point or in the specified box
+    """
+    fg_img_trans = PIL.Image.new("RGBA", bg_img.size)
+    fg_img_trans.paste(fg_img, box, mask=fg_img)
+    new_img = PIL.Image.alpha_composite(bg_img, fg_img_trans)
+    return new_img
+
+
+def add_margin(
+    pil_img: PIL.Image.Image,
+    top: int,
+    right: int,
+    bottom: int,
+    left: int,
+    color: Tuple[int, int, int, int],
+) -> PIL.Image.Image:
+    """
+    Adds margin to the image.
+
+    Args:
+        pil_img (PIL.Image.Image): Image that needed to add margin.
+        top (int): pixels count at top side
+        right (int): pixels count at right side
+        bottom (int): pixels count at bottom side
+        left (int): pixels count at left side
+        color (Tuple[int, int, int, int]): color of margin
+
+    Returns:
+        PIL.Image.Image: Image with margin.
+    """
+    width, height = pil_img.size
+    new_width = width + right + left
+    new_height = height + top + bottom
+    # noinspection PyTypeChecker
+    result = PIL.Image.new(pil_img.mode, (new_width, new_height), color)
+    result.paste(pil_img, (left, top))
+    return result
+
+
+
+
+
+
+
+

Functions

+
+
+def add_margin(pil_img: PIL.Image.Image, top: int, right: int, bottom: int, left: int, color: Tuple[int, int, int, int]) -> PIL.Image.Image +
+
+

Adds margin to the image.

+

Args

+
+
pil_img : PIL.Image.Image
+
Image to which the margin is added.
+
top : int
+
pixels count at top side
+
right : int
+
pixels count at right side
+
bottom : int
+
pixels count at bottom side
+
left : int
+
pixels count at left side
+
color : Tuple[int, int, int, int]
+
color of margin
+
+

Returns

+
+
PIL.Image.Image
+
Image with margin.
+
+
+ +Expand source code + +
def add_margin(
+    pil_img: PIL.Image.Image,
+    top: int,
+    right: int,
+    bottom: int,
+    left: int,
+    color: Tuple[int, int, int, int],
+) -> PIL.Image.Image:
+    """
+    Adds margin to the image.
+
+    Args:
+        pil_img (PIL.Image.Image): Image that needed to add margin.
+        top (int): pixels count at top side
+        right (int): pixels count at right side
+        bottom (int): pixels count at bottom side
+        left (int): pixels count at left side
+        color (Tuple[int, int, int, int]): color of margin
+
+    Returns:
+        PIL.Image.Image: Image with margin.
+    """
+    width, height = pil_img.size
+    new_width = width + right + left
+    new_height = height + top + bottom
+    # noinspection PyTypeChecker
+    result = PIL.Image.new(pil_img.mode, (new_width, new_height), color)
+    result.paste(pil_img, (left, top))
+    return result
+
+
+
+def convert_image(image: PIL.Image.Image, mode='RGB') -> PIL.Image.Image +
+
+

Performs image conversion to correct color mode

+

Args

+
+
image : PIL.Image.Image
+
PIL.Image.Image instance
+
mode : str, default=RGB
+
Color mode to convert the image to.
+
+

Returns

+
+
PIL.Image.Image
+
converted image
+
+

Raises

+
+
ValueError
+
If the image's color mode is not one of RGB, RGBA or L, or if either side of the image is not larger than 32 pixels.
+
+
+ +Expand source code + +
def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image:
+    """Performs image conversion to correct color mode
+
+    Args:
+        image (PIL.Image.Image): `PIL.Image.Image` instance
+        mode (str, default=RGB): Color mode to convert
+
+    Returns:
+        PIL.Image.Image: converted image
+
+    Raises:
+        ValueError: If image hasn't convertable color mode, or it is too small
+    """
+    if is_image_valid(image):
+        return image.convert(mode)
+
+
+
+def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool +
+
+

This function performs image validation.

+

Args

+
+
image : Union[pathlib.Path, PIL.Image.Image]
+
Path to the image or PIL.Image.Image instance being checked.
+
+

Returns

+
+
bool
+
True if the image is valid. Invalid input is reported by raising ValueError; False is never returned.
+
+

Raises

+
+
ValueError
+
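If the path does not exist, is a directory or has an unsupported suffix; if the image's color mode is not one of RGB, RGBA or L, or either side is not larger than 32 pixels; or if the input is neither a pathlib.Path nor a PIL.Image.Image.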
+
+
+ +Expand source code + +
def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool:
+    """This function performs image validation.
+
+    Args:
+        image (Union[pathlib.Path, PIL.Image.Image]): Path to the image or `PIL.Image.Image` instance being checked.
+
+    Returns:
+        bool: True if image is valid, False otherwise.
+
+    Raises:
+        ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small
+
+    """
+    if isinstance(image, pathlib.Path):
+        if not image.exists():
+            raise ValueError("File is not exists")
+        elif image.is_dir():
+            raise ValueError("File is a directory")
+        elif image.suffix.lower() not in ALLOWED_SUFFIXES:
+            raise ValueError(
+                f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}"
+            )
+    elif isinstance(image, PIL.Image.Image):
+        if not (image.size[0] > 32 and image.size[1] > 32):
+            raise ValueError("Image should be bigger then (32x32) pixels.")
+        elif image.mode not in [
+            "RGB",
+            "RGBA",
+            "L",
+        ]:
+            raise ValueError("Wrong image color mode.")
+    else:
+        raise ValueError("Unknown input file type")
+    return True
+
+
+
+def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image +
+
+

Returns a PIL.Image.Image instance for the given string path, pathlib.Path or PIL.Image.Image instance.

+

Args

+
+
file : Union[str, pathlib.Path, PIL.Image.Image]
+
File path or PIL.Image.Image instance
+
+

Returns

+
+
PIL.Image.Image
+
image instance loaded from file location
+
+

Raises

+
+
ValueError
+
If the file does not exist, is a directory or has an unsupported image suffix, or if file is not a str, pathlib.Path or PIL.Image.Image.
+
+
+ +Expand source code + +
def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image:
+    """Returns a `PIL.Image.Image` class by string path or `pathlib.Path` or `PIL.Image.Image` instance
+
+    Args:
+        file (Union[str, pathlib.Path, PIL.Image.Image]): File path or `PIL.Image.Image` instance
+
+    Returns:
+        PIL.Image.Image: image instance loaded from `file` location
+
+    Raises:
+        ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image
+
+    """
+    if isinstance(file, str) and is_image_valid(pathlib.Path(file)):
+        return PIL.Image.open(file)
+    elif isinstance(file, PIL.Image.Image):
+        return file
+    elif isinstance(file, pathlib.Path) and is_image_valid(file):
+        return PIL.Image.open(str(file))
+    else:
+        raise ValueError("Unknown input file type")
+
+
+
+def to_tensor(x: Any) -> torch.Tensor +
+
+

Returns a PIL.Image.Image as a torch tensor without swapping the tensor dimensions.

+

Args

+
+
x : PIL.Image.Image
+
image
+
+

Returns

+
+
torch.Tensor
+
image as torch tensor
+
+
+ +Expand source code + +
def to_tensor(x: Any) -> torch.Tensor:
+    """
+    Returns a PIL.Image.Image as torch tensor without swap tensor dims.
+
+    Args:
+        x (PIL.Image.Image): image
+
+    Returns:
+        torch.Tensor: image as torch tensor
+    """
+    return torch.tensor(np.array(x, copy=True))
+
+
+
+def transparency_paste(bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)) -> PIL.Image.Image +
+
+

Inserts an image into another image while maintaining transparency.

+

Args

+
+
bg_img : PIL.Image.Image
+
background image
+
fg_img : PIL.Image.Image
+
foreground image
+
box : tuple[int, int]
+
Point (or box) in the background image at which the foreground image is pasted.
+
+

Returns

+
+
PIL.Image.Image
+
Background image with pasted foreground image at point or in the specified box
+
+
+ +Expand source code + +
def transparency_paste(
+    bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)
+) -> PIL.Image.Image:
+    """
+    Inserts an image into another image while maintaining transparency.
+
+    Args:
+        bg_img (PIL.Image.Image): background image
+        fg_img (PIL.Image.Image): foreground image
+        box (tuple[int, int]): place to paste
+
+    Returns:
+        PIL.Image.Image: Background image with pasted foreground image at point or in the specified box
+    """
+    fg_img_trans = PIL.Image.new("RGBA", bg_img.size)
+    fg_img_trans.paste(fg_img, box, mask=fg_img)
+    new_img = PIL.Image.alpha_composite(bg_img, fg_img_trans)
+    return new_img
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/index.html b/docs/api/carvekit/utils/index.html new file mode 100644 index 0000000..851f59a --- /dev/null +++ b/docs/api/carvekit/utils/index.html @@ -0,0 +1,92 @@ + + + + + + +carvekit.utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils

+
+
+
+
+

Sub-modules

+
+
carvekit.utils.download_models
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO). +License: ...

+
+
carvekit.utils.fs_utils
+
+ +
+
carvekit.utils.image_utils
+
+ +
+
carvekit.utils.mask_utils
+
+ +
+
carvekit.utils.models_utils
+
+ +
+
carvekit.utils.pool_utils
+
+ +
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/mask_utils.html b/docs/api/carvekit/utils/mask_utils.html new file mode 100644 index 0000000..4eb1ca2 --- /dev/null +++ b/docs/api/carvekit/utils/mask_utils.html @@ -0,0 +1,303 @@ + + + + + + +carvekit.utils.mask_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils.mask_utils

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO, https://github.com/OPHoperHPO).

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import PIL.Image
+import torch
+from carvekit.utils.image_utils import to_tensor
+
+
+def composite(
+    foreground: PIL.Image.Image,
+    background: PIL.Image.Image,
+    alpha: PIL.Image.Image,
+    device="cpu",
+):
+    """
+    Composites foreground with background by following
+    https://pymatting.github.io/intro.html#alpha-matting math formula.
+
+    Args:
+        foreground (PIL.Image.Image): Image that will be pasted to background image with following alpha mask.
+        background (PIL.Image.Image): Background image
+        alpha (PIL.Image.Image): Alpha Image
+        device (Literal[cpu, cuda]): Processing device
+
+    Returns:
+        PIL.Image.Image: Composited image.
+    """
+
+    foreground = foreground.convert("RGBA")
+    background = background.convert("RGBA")
+    alpha_rgba = alpha.convert("RGBA")
+    alpha_l = alpha.convert("L")
+
+    fg = to_tensor(foreground).to(device)
+    alpha_rgba = to_tensor(alpha_rgba).to(device)
+    alpha_l = to_tensor(alpha_l).to(device)
+    bg = to_tensor(background).to(device)
+
+    alpha_l = alpha_l / 255
+    alpha_rgba = alpha_rgba / 255
+
+    bg = torch.where(torch.logical_not(alpha_rgba >= 1), bg, fg)
+    bg[:, :, 0] = alpha_l[:, :] * fg[:, :, 0] + (1 - alpha_l[:, :]) * bg[:, :, 0]
+    bg[:, :, 1] = alpha_l[:, :] * fg[:, :, 1] + (1 - alpha_l[:, :]) * bg[:, :, 1]
+    bg[:, :, 2] = alpha_l[:, :] * fg[:, :, 2] + (1 - alpha_l[:, :]) * bg[:, :, 2]
+    bg[:, :, 3] = alpha_l[:, :] * 255
+
+    del alpha_l, alpha_rgba, fg
+    return PIL.Image.fromarray(bg.cpu().numpy()).convert("RGBA")
+
+
+def apply_mask(
+    image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu"
+) -> PIL.Image.Image:
+    """
+    Applies mask to foreground.
+
+    Args:
+        image (PIL.Image.Image): Image with background.
+        mask (PIL.Image.Image): Alpha Channel mask for this image.
+        device (Literal[cpu, cuda]): Processing device.
+
+    Returns:
+        PIL.Image.Image: Image without background, where mask was black.
+    """
+    background = PIL.Image.new("RGBA", image.size, color=(130, 130, 130, 0))
+    return composite(image, background, mask, device=device).convert("RGBA")
+
+
+def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image:
+    """
+    Extracts alpha channel from the RGBA image.
+
+    Args:
+        image: RGBA PIL image
+
+    Returns:
+        PIL.Image.Image: RGBA alpha channel image
+    """
+    alpha = image.split()[-1]
+    bg = PIL.Image.new("RGBA", image.size, (0, 0, 0, 255))
+    bg.paste(alpha, mask=alpha)
+    return bg.convert("RGBA")
+
+
+
+
+
+
+
+

Functions

+
+
+def apply_mask(image: PIL.Image.Image, mask: PIL.Image.Image, device='cpu') -> PIL.Image.Image +
+
+

Applies mask to foreground.

+

Args

+
+
image : PIL.Image.Image
+
Image with background.
+
mask : PIL.Image.Image
+
Alpha Channel mask for this image.
+
device : Literal[cpu, cuda]
+
Processing device.
+
+

Returns

+
+
PIL.Image.Image
+
Image with the background removed: it is transparent wherever the mask was black.
+
+
+ +Expand source code + +
def apply_mask(
+    image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu"
+) -> PIL.Image.Image:
+    """
+    Applies mask to foreground.
+
+    Args:
+        image (PIL.Image.Image): Image with background.
+        mask (PIL.Image.Image): Alpha Channel mask for this image.
+        device (Literal[cpu, cuda]): Processing device.
+
+    Returns:
+        PIL.Image.Image: Image without background, where mask was black.
+    """
+    background = PIL.Image.new("RGBA", image.size, color=(130, 130, 130, 0))
+    return composite(image, background, mask, device=device).convert("RGBA")
+
+
+
+def composite(foreground: PIL.Image.Image, background: PIL.Image.Image, alpha: PIL.Image.Image, device='cpu') +
+
+

Composites the foreground with the background using the alpha-matting formula described at +https://pymatting.github.io/intro.html#alpha-matting.

+

Args

+
+
foreground : PIL.Image.Image
+
Image that will be pasted onto the background image using the given alpha mask.
+
background : PIL.Image.Image
+
Background image
+
alpha : PIL.Image.Image
+
Alpha Image
+
device : Literal[cpu, cuda]
+
Processing device
+
+

Returns

+
+
PIL.Image.Image
+
Composited image.
+
+
+ +Expand source code + +
def composite(
+    foreground: PIL.Image.Image,
+    background: PIL.Image.Image,
+    alpha: PIL.Image.Image,
+    device="cpu",
+):
+    """
+    Composites foreground with background by following
+    https://pymatting.github.io/intro.html#alpha-matting math formula.
+
+    Args:
+        foreground (PIL.Image.Image): Image that will be pasted to background image with following alpha mask.
+        background (PIL.Image.Image): Background image
+        alpha (PIL.Image.Image): Alpha Image
+        device (Literal[cpu, cuda]): Processing device
+
+    Returns:
+        PIL.Image.Image: Composited image.
+    """
+
+    foreground = foreground.convert("RGBA")
+    background = background.convert("RGBA")
+    alpha_rgba = alpha.convert("RGBA")
+    alpha_l = alpha.convert("L")
+
+    fg = to_tensor(foreground).to(device)
+    alpha_rgba = to_tensor(alpha_rgba).to(device)
+    alpha_l = to_tensor(alpha_l).to(device)
+    bg = to_tensor(background).to(device)
+
+    alpha_l = alpha_l / 255
+    alpha_rgba = alpha_rgba / 255
+
+    bg = torch.where(torch.logical_not(alpha_rgba >= 1), bg, fg)
+    bg[:, :, 0] = alpha_l[:, :] * fg[:, :, 0] + (1 - alpha_l[:, :]) * bg[:, :, 0]
+    bg[:, :, 1] = alpha_l[:, :] * fg[:, :, 1] + (1 - alpha_l[:, :]) * bg[:, :, 1]
+    bg[:, :, 2] = alpha_l[:, :] * fg[:, :, 2] + (1 - alpha_l[:, :]) * bg[:, :, 2]
+    bg[:, :, 3] = alpha_l[:, :] * 255
+
+    del alpha_l, alpha_rgba, fg
+    return PIL.Image.fromarray(bg.cpu().numpy()).convert("RGBA")
+
+
+
+def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image +
+
+

Extracts alpha channel from the RGBA image.

+

Args

+
+
image
+
RGBA PIL image
+
+

Returns

+
+
PIL.Image.Image
+
RGBA alpha channel image
+
+
+ +Expand source code + +
def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image:
+    """
+    Produces an image built from the last band (the alpha band for RGBA input)
+    of the given image.
+
+    Args:
+        image: RGBA PIL image
+
+    Returns:
+        PIL.Image.Image: RGBA alpha channel image
+    """
+    # The last band returned by split() is used as the alpha channel.
+    *_, alpha_band = image.split()
+    # Start from an opaque black canvas and paste the band through itself as mask.
+    canvas = PIL.Image.new("RGBA", image.size, (0, 0, 0, 255))
+    canvas.paste(alpha_band, mask=alpha_band)
+    return canvas.convert("RGBA")
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/models_utils.html b/docs/api/carvekit/utils/models_utils.html new file mode 100644 index 0000000..4002420 --- /dev/null +++ b/docs/api/carvekit/utils/models_utils.html @@ -0,0 +1,399 @@ + + + + + + +carvekit.utils.models_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils.models_utils

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+
+import random
+import warnings
+from typing import Union, Tuple, Any
+
+import torch
+from torch import autocast
+
+
+class EmptyAutocast(object):
+    """
+    No-op replacement for an autocast object, used when auto-casting is disabled.
+
+    It can be used as a context manager (``with EmptyAutocast(): ...``) or as a
+    decorator (``@EmptyAutocast()``); in both cases it changes nothing.
+    """
+
+    def __enter__(self):
+        # Nothing to set up; ``with ... as x`` binds ``None``.
+        return None
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Returning None lets any exception raised inside the block propagate.
+        return
+
+    def __call__(self, func):
+        # Decorator form: hand the function back untouched. Returning None here
+        # would replace the decorated function with None.
+        return func
+
+
+def get_precision_autocast(
+    device="cpu", fp16=True, override_dtype=None
+) -> Union[
+    Tuple[EmptyAutocast, Union[torch.dtype, Any]],
+    Tuple[autocast, Union[torch.dtype, Any]],
+]:
+    """
+    Returns precision and auto-cast settings for given device and fp16 settings.
+
+    Args:
+        device (str): Device string, e.g. "cpu", "cuda" or an indexed form such as "cuda:0".
+        fp16 (bool): Whether to use fp16 precision. Ignored (with a warning) on CPU.
+        override_dtype (torch.dtype, optional): When given, replaces the dtype chosen
+            from ``device``/``fp16``.
+
+    Returns:
+        Union[Tuple[EmptyAutocast, Union[torch.dtype, Any]],Tuple[autocast, Union[torch.dtype, Any]]]: Autocast object, dtype
+    """
+    dtype = torch.float32
+    cache_enabled = None
+
+    if device == "cpu" and fp16:
+        warnings.warn("FP16 is not supported on CPU. Using FP32 instead.")
+        dtype = torch.float32
+        # TODO: Implement BFP16 on CPU (torch.bfloat16, since autocast is not supported
+        # for float16 there). There are unexpected slowdowns on cpu on a clean environment.
+
+    if "cuda" in device and fp16:
+        dtype = torch.float16
+        cache_enabled = True
+
+    if override_dtype is not None:
+        dtype = override_dtype
+
+    if dtype == torch.float32 and device == "cpu":
+        # Plain FP32 on CPU needs no autocast at all.
+        return EmptyAutocast(), dtype
+
+    # torch.autocast expects a bare device type ("cuda"), not an indexed device
+    # string ("cuda:0"), so strip any ":<index>" suffix before passing it on.
+    device_type = device.split(":", 1)[0]
+    return (
+        torch.autocast(
+            device_type=device_type,
+            dtype=dtype,
+            enabled=True,
+            cache_enabled=cache_enabled,
+        ),
+        dtype,
+    )
+
+
+def cast_network(network: torch.nn.Module, dtype: torch.dtype):
+    """
+    Casts the network to the requested dtype by calling the matching
+    ``half()`` / ``bfloat16()`` / ``float()`` method on it.
+
+    Args:
+        network (torch.nn.Module): Network to be cast
+        dtype (torch.dtype): Dtype to cast network to
+
+    Raises:
+        ValueError: If ``dtype`` is not float16, bfloat16 or float32.
+    """
+    supported = (
+        (torch.float16, "half"),
+        (torch.bfloat16, "bfloat16"),
+        (torch.float32, "float"),
+    )
+    for candidate, method_name in supported:
+        if dtype == candidate:
+            getattr(network, method_name)()
+            return
+    raise ValueError(f"Unknown dtype {dtype}")
+
+
+def fix_seed(seed: int = 42):
+    """
+    Seeds Python's ``random`` module and torch with a fixed value.
+
+    When CUDA is available the CUDA generators are seeded too, and cuDNN is
+    switched to deterministic mode with benchmarking turned off.
+
+    Args:
+        seed (int, default=42): Random seed to be set
+
+    Returns:
+        bool: Always ``True``.
+    """
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if not torch.cuda.is_available():
+        return True
+
+    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
+        cuda_seeder(seed)
+    # noinspection PyUnresolvedReferences
+    cudnn = torch.backends.cudnn
+    cudnn.deterministic = True
+    cudnn.benchmark = False
+    return True
+
+
+def suppress_warnings():
+    """
+    Installs a warnings filter that ignores one specific ``UserWarning`` raised
+    from modules whose name starts with "torch".
+    """
+    # The PyTorch 1.11.0 warning about the changing order of args in the nn.MaxPool2d
+    # layer is hidden because this code base is not affected by that issue and
+    # there is no other correct way to silence the message.
+    maxpool_notice = (
+        "Note that order of the arguments: ceil_mode and "
+        "return_indices will changeto match the args list "
+        "in nn.MaxPool2d in a future release."
+    )
+    warnings.filterwarnings(
+        action="ignore",
+        message=maxpool_notice,
+        category=UserWarning,
+        module="torch",
+    )
+
+
+
+
+
+
+
+

Functions

+
+
+def cast_network(network: torch.nn.modules.module.Module, dtype: torch.dtype) +
+
+

Cast network to given dtype

+

Args

+
+
network : torch.nn.Module
+
Network to be cast
+
dtype : torch.dtype
+
Dtype to cast network to
+
+
+ +Expand source code + +
def cast_network(network: torch.nn.Module, dtype: torch.dtype):
+    """
+    Casts the network to the requested dtype by calling the matching
+    ``half()`` / ``bfloat16()`` / ``float()`` method on it.
+
+    Args:
+        network (torch.nn.Module): Network to be cast
+        dtype (torch.dtype): Dtype to cast network to
+
+    Raises:
+        ValueError: If ``dtype`` is not float16, bfloat16 or float32.
+    """
+    supported = (
+        (torch.float16, "half"),
+        (torch.bfloat16, "bfloat16"),
+        (torch.float32, "float"),
+    )
+    for candidate, method_name in supported:
+        if dtype == candidate:
+            getattr(network, method_name)()
+            return
+    raise ValueError(f"Unknown dtype {dtype}")
+
+
+
+def fix_seed(seed: int = 42) +
+
+

Sets fixed random seed

+

Args

+
+
seed : int, default=42
+
Random seed to be set
+
+
+ +Expand source code + +
def fix_seed(seed: int = 42):
+    """
+    Seeds Python's ``random`` module and torch with a fixed value.
+
+    When CUDA is available the CUDA generators are seeded too, and cuDNN is
+    switched to deterministic mode with benchmarking turned off.
+
+    Args:
+        seed (int, default=42): Random seed to be set
+
+    Returns:
+        bool: Always ``True``.
+    """
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if not torch.cuda.is_available():
+        return True
+
+    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
+        cuda_seeder(seed)
+    # noinspection PyUnresolvedReferences
+    cudnn = torch.backends.cudnn
+    cudnn.deterministic = True
+    cudnn.benchmark = False
+    return True
+
+
+
+def get_precision_autocast(device='cpu', fp16=True, override_dtype=None) -> Union[Tuple[EmptyAutocast, Union[torch.dtype, Any]], Tuple[torch.autocast_mode.autocast, Union[torch.dtype, Any]]] +
+
+

Returns precision and auto-cast settings for given device and fp16 settings.

+

Args

+
+
device : Literal[cpu, cuda]
+
Device to get precision and auto-cast settings for.
+
fp16 : bool
+
Whether to use fp16 precision.
+
override_dtype : torch.dtype, optional
+
Override dtype for auto-cast.
+
+

Returns

+
+
Union[Tuple[EmptyAutocast, Union[torch.dtype, Any]],Tuple[autocast, Union[torch.dtype, Any]]]
+
Autocast object, dtype
+
+
+ +Expand source code + +
def get_precision_autocast(
+    device="cpu", fp16=True, override_dtype=None
+) -> Union[
+    Tuple[EmptyAutocast, Union[torch.dtype, Any]],
+    Tuple[autocast, Union[torch.dtype, Any]],
+]:
+    """
+    Returns precision and auto-cast settings for given device and fp16 settings.
+
+    Args:
+        device (str): Device string, e.g. "cpu", "cuda" or an indexed form such as "cuda:0".
+        fp16 (bool): Whether to use fp16 precision. Ignored (with a warning) on CPU.
+        override_dtype (torch.dtype, optional): When given, replaces the dtype chosen
+            from ``device``/``fp16``.
+
+    Returns:
+        Union[Tuple[EmptyAutocast, Union[torch.dtype, Any]],Tuple[autocast, Union[torch.dtype, Any]]]: Autocast object, dtype
+    """
+    dtype = torch.float32
+    cache_enabled = None
+
+    if device == "cpu" and fp16:
+        warnings.warn("FP16 is not supported on CPU. Using FP32 instead.")
+        dtype = torch.float32
+        # TODO: Implement BFP16 on CPU (torch.bfloat16, since autocast is not supported
+        # for float16 there). There are unexpected slowdowns on cpu on a clean environment.
+
+    if "cuda" in device and fp16:
+        dtype = torch.float16
+        cache_enabled = True
+
+    if override_dtype is not None:
+        dtype = override_dtype
+
+    if dtype == torch.float32 and device == "cpu":
+        # Plain FP32 on CPU needs no autocast at all.
+        return EmptyAutocast(), dtype
+
+    # torch.autocast expects a bare device type ("cuda"), not an indexed device
+    # string ("cuda:0"), so strip any ":<index>" suffix before passing it on.
+    device_type = device.split(":", 1)[0]
+    return (
+        torch.autocast(
+            device_type=device_type,
+            dtype=dtype,
+            enabled=True,
+            cache_enabled=cache_enabled,
+        ),
+        dtype,
+    )
+
+
+
+def suppress_warnings() +
+
+
+
+ +Expand source code + +
def suppress_warnings():
+    """
+    Installs a warnings filter that ignores one specific ``UserWarning`` raised
+    from modules whose name starts with "torch".
+    """
+    # The PyTorch 1.11.0 warning about the changing order of args in the nn.MaxPool2d
+    # layer is hidden because this code base is not affected by that issue and
+    # there is no other correct way to silence the message.
+    maxpool_notice = (
+        "Note that order of the arguments: ceil_mode and "
+        "return_indices will changeto match the args list "
+        "in nn.MaxPool2d in a future release."
+    )
+    warnings.filterwarnings(
+        action="ignore",
+        message=maxpool_notice,
+        category=UserWarning,
+        module="torch",
+    )
+
+
+
+
+
+

Classes

+
+
+class EmptyAutocast +
+
+

Empty class for any auto-casting disabling.

+
+ +Expand source code + +
class EmptyAutocast(object):
+    """
+    No-op replacement for an autocast object, used when auto-casting is disabled.
+
+    It can be used as a context manager (``with EmptyAutocast(): ...``) or as a
+    decorator (``@EmptyAutocast()``); in both cases it changes nothing.
+    """
+
+    def __enter__(self):
+        # Nothing to set up; ``with ... as x`` binds ``None``.
+        return None
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Returning None lets any exception raised inside the block propagate.
+        return
+
+    def __call__(self, func):
+        # Decorator form: hand the function back untouched. Returning None here
+        # would replace the decorated function with None.
+        return func
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/utils/pool_utils.html b/docs/api/carvekit/utils/pool_utils.html new file mode 100644 index 0000000..abcae5f --- /dev/null +++ b/docs/api/carvekit/utils/pool_utils.html @@ -0,0 +1,189 @@ + + + + + + +carvekit.utils.pool_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.utils.pool_utils

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Iterable, Callable, Collection, List
+
+
+def thread_pool_processing(func: Callable[[Any], Any], data: Iterable, workers=18):
+    """
+    Applies ``func`` to every item of ``data`` using a thread pool.
+
+    Args:
+        workers (int, default=18): Count of workers.
+        func (Callable[[Any], Any]): function to pass data through
+        data (Iterable): input iterator
+
+    Returns:
+        List[Any]: list of results, in the order of the input items
+    """
+    with ThreadPoolExecutor(workers) as executor:
+        mapped = executor.map(func, data)
+        return list(mapped)
+
+
+def batch_generator(iterable: Collection, n: int = 1) -> Iterable[Collection]:
+    """
+    Yields consecutive slices of at most ``n`` items from a sized, sliceable input.
+
+    Args:
+        iterable (Collection): input supporting ``len()`` and slicing
+        n (int, default=1): size of packets
+
+    Returns:
+        Iterable[Collection]: new n-size packet (the last one may be shorter)
+    """
+    total = len(iterable)
+    for start in range(0, total, n):
+        stop = start + n
+        if stop > total:
+            # Clamp the final packet to the end of the input.
+            stop = total
+        yield iterable[start:stop]
+
+
+
+
+
+
+
+

Functions

+
+
+def batch_generator(iterable: Collection, n: int = 1) -> Iterable[Collection] +
+
+

Splits any iterable into n-size packets

+

Args

+
+
iterable : Collection
+
iterator
+
n : int, default=1
+
size of packets
+
+

Returns

+
+
Iterable[Collection]
+
new n-size packet
+
+
+ +Expand source code + +
def batch_generator(iterable: Collection, n: int = 1) -> Iterable[Collection]:
+    """
+    Yields consecutive slices of at most ``n`` items from a sized, sliceable input.
+
+    Args:
+        iterable (Collection): input supporting ``len()`` and slicing
+        n (int, default=1): size of packets
+
+    Returns:
+        Iterable[Collection]: new n-size packet (the last one may be shorter)
+    """
+    total = len(iterable)
+    for start in range(0, total, n):
+        stop = start + n
+        if stop > total:
+            # Clamp the final packet to the end of the input.
+            stop = total
+        yield iterable[start:stop]
+
+
+
+def thread_pool_processing(func: Callable[[Any], Any], data: Iterable, workers=18) +
+
+

Passes all iterator data through the given function

+

Args

+
+
workers : int, default=18
+
Count of workers.
+
func : Callable[[Any], Any]
+
function to pass data through
+
data : Iterable
+
input iterator
+
+

Returns

+
+
List[Any]
+
list of results
+
+
+ +Expand source code + +
def thread_pool_processing(func: Callable[[Any], Any], data: Iterable, workers=18):
+    """
+    Applies ``func`` to every item of ``data`` using a thread pool.
+
+    Args:
+        workers (int, default=18): Count of workers.
+        func (Callable[[Any], Any]): function to pass data through
+        data (Iterable): input iterator
+
+    Returns:
+        List[Any]: list of results, in the order of the input items
+    """
+    with ThreadPoolExecutor(workers) as executor:
+        mapped = executor.map(func, data)
+        return list(mapped)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/app.html b/docs/api/carvekit/web/app.html new file mode 100644 index 0000000..72495e0 --- /dev/null +++ b/docs/api/carvekit/web/app.html @@ -0,0 +1,88 @@ + + + + + + +carvekit.web.app API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.app

+
+
+
+ +Expand source code + +
from pathlib import Path
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from starlette.staticfiles import StaticFiles
+
+from carvekit import version
+from carvekit.web.deps import config
+from carvekit.web.routers.api_router import api_router
+
+# Application object; the API version shown is the carvekit package version.
+app = FastAPI(title="CarveKit Web API", version=version)
+
+# CORS is fully open: any origin, method and header is accepted, with credentials.
+# NOTE(review): wildcard origins combined with allow_credentials=True is very
+# permissive - confirm this is intended for deployments exposed publicly.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# The API router is registered under /api BEFORE the static mount at "/":
+# the mount at the root path would otherwise match every request first.
+app.include_router(api_router, prefix="/api")
+# Serve the bundled "static" directory (next to this file) at the site root;
+# html=True enables HTML-mode serving for that directory.
+app.mount(
+    "/",
+    StaticFiles(directory=Path(__file__).parent.joinpath("static"), html=True),
+    name="static",
+)
+
+# Running the module directly starts uvicorn on the configured host and port.
+if __name__ == "__main__":
+    uvicorn.run(app, host=config.host, port=config.port)
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/deps.html b/docs/api/carvekit/web/deps.html new file mode 100644 index 0000000..e8b94c3 --- /dev/null +++ b/docs/api/carvekit/web/deps.html @@ -0,0 +1,64 @@ + + + + + + +carvekit.web.deps API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.deps

+
+
+
+ +Expand source code + +
from carvekit.web.schemas.config import WebAPIConfig
+from carvekit.web.utils.init_utils import init_config
+from carvekit.web.utils.task_queue import MLProcessor
+
+# Module-level singletons: both objects are created once, when this module is
+# first imported, and are shared by everything that imports them from here.
+config: WebAPIConfig = init_config()
+# The processor is built from the config above, so it must be created second.
+ml_processor = MLProcessor(api_config=config)
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/handlers/index.html b/docs/api/carvekit/web/handlers/index.html new file mode 100644 index 0000000..82a3996 --- /dev/null +++ b/docs/api/carvekit/web/handlers/index.html @@ -0,0 +1,65 @@ + + + + + + +carvekit.web.handlers API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.handlers

+
+
+
+
+

Sub-modules

+
+
carvekit.web.handlers.response
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/handlers/response.html b/docs/api/carvekit/web/handlers/response.html new file mode 100644 index 0000000..ced0e8f --- /dev/null +++ b/docs/api/carvekit/web/handlers/response.html @@ -0,0 +1,205 @@ + + + + + + +carvekit.web.handlers.response API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.handlers.response

+
+
+
+ +Expand source code + +
from typing import Union
+
+from fastapi import Header
+from fastapi.responses import Response, JSONResponse
+from carvekit.web.deps import config
+
+
+def Authenticate(x_api_key: Union[str, None] = Header(None)) -> Union[bool, str]:
+    """
+    Resolves the access level granted to the supplied API key.
+
+    Args:
+        x_api_key: API key taken from the request header, or None when absent.
+
+    Returns:
+        Union[bool, str]: "allowed" for a key listed in the allowed tokens (or for any
+        non-admin key when authentication is switched off), "admin" for the admin
+        token, otherwise False.
+    """
+    auth_settings = config.auth
+    if x_api_key in auth_settings.allowed_tokens:
+        return "allowed"
+    if x_api_key == auth_settings.admin_token:
+        return "admin"
+    if auth_settings.auth is False:
+        # Authentication disabled: everyone gets regular access.
+        return "allowed"
+    return False
+
+
+def handle_response(response, original_image) -> Response:
+    """
+    Converts a TaskQueue result into an HTTP response object.
+
+    Args:
+        response: TaskQueue response. Either a dict with a "type" key ("jpg", "png"
+            or "zip") and a "data" key holding ``[payload, (width, height)]``, or an
+            error sequence whose first item is the JSON-serialisable error body and
+            whose optional second item is the HTTP status code.
+        original_image: Original PIL image; its size is reported in the
+            X-Max-Width / X-Max-Height headers.
+
+    Returns:
+        Response: Complete FastAPI response (a JSONResponse for error results).
+    """
+    response_object = None
+    if isinstance(response, dict):
+        if response["type"] == "jpg":
+            response_object = Response(
+                content=response["data"][0].read(), media_type="image/jpeg"
+            )
+        elif response["type"] == "png":
+            response_object = Response(
+                content=response["data"][0].read(), media_type="image/png"
+            )
+        elif response["type"] == "zip":
+            # The zip payload is already raw bytes, so there is no .read() here.
+            response_object = Response(
+                content=response["data"][0], media_type="application/zip"
+            )
+            response_object.headers[
+                "Content-Disposition"
+            ] = "attachment; filename='no-bg.zip'"
+
+        # Add headers to output result
+        response_object.headers["X-Credits-Charged"] = "0"
+        response_object.headers["X-Type"] = "other"  # TODO Make support for this
+        response_object.headers["X-Max-Width"] = str(original_image.size[0])
+        response_object.headers["X-Max-Height"] = str(original_image.size[1])
+        response_object.headers[
+            "X-Ratelimit-Limit"
+        ] = "500"  # TODO Make ratelimit support
+        response_object.headers["X-Ratelimit-Remaining"] = "500"
+        response_object.headers["X-Ratelimit-Reset"] = "1"
+        response_object.headers["X-Width"] = str(response["data"][1][0])
+        response_object.headers["X-Height"] = str(response["data"][1][1])
+
+    else:
+        # Error result. The JSON response used to be stored in a different variable,
+        # so this branch returned None; it is now assigned to the returned object.
+        # The status code carried next to the error body is forwarded when present.
+        status_code = 200
+        if len(response) > 1 and isinstance(response[1], int):
+            status_code = response[1]
+        response_object = JSONResponse(content=response[0], status_code=status_code)
+        response_object.headers["X-Credits-Charged"] = "0"
+
+    return response_object
+
+
+
+
+
+
+
+

Functions

+
+
+def Authenticate(x_api_key: Optional[str] = Header(None)) -> Union[bool, str] +
+
+
+
+ +Expand source code + +
def Authenticate(x_api_key: Union[str, None] = Header(None)) -> Union[bool, str]:
+    """
+    Resolves the access level granted to the supplied API key.
+
+    Args:
+        x_api_key: API key taken from the request header, or None when absent.
+
+    Returns:
+        Union[bool, str]: "allowed" for a key listed in the allowed tokens (or for any
+        non-admin key when authentication is switched off), "admin" for the admin
+        token, otherwise False.
+    """
+    auth_settings = config.auth
+    if x_api_key in auth_settings.allowed_tokens:
+        return "allowed"
+    if x_api_key == auth_settings.admin_token:
+        return "admin"
+    if auth_settings.auth is False:
+        # Authentication disabled: everyone gets regular access.
+        return "allowed"
+    return False
+
+
+
+def handle_response(response, original_image) -> starlette.responses.Response +
+
+

Response handler from TaskQueue +:param response: TaskQueue response +:param original_image: Original PIL image +:return: Complete FastAPI response

+
+ +Expand source code + +
def handle_response(response, original_image) -> Response:
+    """
+    Converts a TaskQueue result into an HTTP response object.
+
+    Args:
+        response: TaskQueue response. Either a dict with a "type" key ("jpg", "png"
+            or "zip") and a "data" key holding ``[payload, (width, height)]``, or an
+            error sequence whose first item is the JSON-serialisable error body and
+            whose optional second item is the HTTP status code.
+        original_image: Original PIL image; its size is reported in the
+            X-Max-Width / X-Max-Height headers.
+
+    Returns:
+        Response: Complete FastAPI response (a JSONResponse for error results).
+    """
+    response_object = None
+    if isinstance(response, dict):
+        if response["type"] == "jpg":
+            response_object = Response(
+                content=response["data"][0].read(), media_type="image/jpeg"
+            )
+        elif response["type"] == "png":
+            response_object = Response(
+                content=response["data"][0].read(), media_type="image/png"
+            )
+        elif response["type"] == "zip":
+            # The zip payload is already raw bytes, so there is no .read() here.
+            response_object = Response(
+                content=response["data"][0], media_type="application/zip"
+            )
+            response_object.headers[
+                "Content-Disposition"
+            ] = "attachment; filename='no-bg.zip'"
+
+        # Add headers to output result
+        response_object.headers["X-Credits-Charged"] = "0"
+        response_object.headers["X-Type"] = "other"  # TODO Make support for this
+        response_object.headers["X-Max-Width"] = str(original_image.size[0])
+        response_object.headers["X-Max-Height"] = str(original_image.size[1])
+        response_object.headers[
+            "X-Ratelimit-Limit"
+        ] = "500"  # TODO Make ratelimit support
+        response_object.headers["X-Ratelimit-Remaining"] = "500"
+        response_object.headers["X-Ratelimit-Reset"] = "1"
+        response_object.headers["X-Width"] = str(response["data"][1][0])
+        response_object.headers["X-Height"] = str(response["data"][1][1])
+
+    else:
+        # Error result. The JSON response used to be stored in a different variable,
+        # so this branch returned None; it is now assigned to the returned object.
+        # The status code carried next to the error body is forwarded when present.
+        status_code = 200
+        if len(response) > 1 and isinstance(response[1], int):
+            status_code = response[1]
+        response_object = JSONResponse(content=response[0], status_code=status_code)
+        response_object.headers["X-Credits-Charged"] = "0"
+
+    return response_object
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/index.html b/docs/api/carvekit/web/index.html new file mode 100644 index 0000000..59bd3d6 --- /dev/null +++ b/docs/api/carvekit/web/index.html @@ -0,0 +1,100 @@ + + + + + + +carvekit.web API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web

+
+
+
+
+

Sub-modules

+
+
carvekit.web.app
+
+
+
+
carvekit.web.deps
+
+
+
+
carvekit.web.handlers
+
+
+
+
carvekit.web.other
+
+
+
+
carvekit.web.responses
+
+
+
+
carvekit.web.routers
+
+
+
+
carvekit.web.schemas
+
+
+
+
carvekit.web.utils
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/other/index.html b/docs/api/carvekit/web/other/index.html new file mode 100644 index 0000000..6d6bdb0 --- /dev/null +++ b/docs/api/carvekit/web/other/index.html @@ -0,0 +1,65 @@ + + + + + + +carvekit.web.other API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.other

+
+
+
+
+

Sub-modules

+
+
carvekit.web.other.removebg
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/other/removebg.html b/docs/api/carvekit/web/other/removebg.html new file mode 100644 index 0000000..e657a22 --- /dev/null +++ b/docs/api/carvekit/web/other/removebg.html @@ -0,0 +1,571 @@ + + + + + + +carvekit.web.other.removebg API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.other.removebg

+
+
+
+ +Expand source code + +
import io
+import time
+import zipfile
+
+import requests
+from PIL import Image, ImageColor
+
+from carvekit.utils.image_utils import transparency_paste, add_margin
+from carvekit.utils.mask_utils import extract_alpha_channel
+from carvekit.web.responses.api import error_dict
+from carvekit.api.interface import Interface
+
+
def process_remove_bg(
+    interface: Interface, params, image, bg, is_json_or_www_encoded=False
+):
+    """
+    Handles a request to the removebg api method
+
+    Args:
+        interface: CarveKit interface
+        params: parameters (mapping of remove.bg-style options such as "size", "roi",
+            "scale", "crop", "crop_margin", "position", "channels", "bg_color",
+            "bg_image_url" and "format")
+        image: foreground pil image (may be downscaled in place by the "size" option)
+        bg: background pil image
+        is_json_or_www_encoded: is "json" or "x-www-form-urlencoded" content-type
+
+    Returns:
+        On success a dict ``{"type": "jpg" | "zip" | "png", "data": [payload, size]}``;
+        on failure a tuple ``(error_dict(...), 400)``.
+    """
+    # PIL's ``size`` is (width, height).
+    width, height = image.size
+    if width < 2 or height < 2:
+        return error_dict("Image is too small. Minimum size 2x2"), 400
+
+    if "size" in params.keys():
+        value = params["size"]
+        if value == "preview" or value == "small" or value == "regular":
+            image.thumbnail((625, 400), resample=3)  # 0.25 mp
+        elif value == "medium":
+            image.thumbnail((1504, 1000), resample=3)  # 1.5 mp
+        elif value == "hd":
+            image.thumbnail((2000, 2000), resample=3)  # 2.5 mp
+        else:
+            image.thumbnail((6250, 4000), resample=3)  # 25 mp
+
+    # Region of interest defaults to the whole (possibly downscaled) image.
+    roi_box = [0, 0, image.size[0], image.size[1]]
+    # NOTE: the "type" parameter is accepted but currently has no effect.
+
+    if "roi" in params.keys():
+        value = params["roi"].split(" ")
+        if len(value) == 4:
+            for i, coord in enumerate(value):
+                if "px" in coord:
+                    coord = coord.replace("px", "")
+                    try:
+                        coord = int(coord)
+                    except ValueError:
+                        return (
+                            error_dict(
+                                "Error converting roi coordinate string to number!"
+                            ),
+                            400,
+                        )
+                    if coord < 0:
+                        # This error used to be built but never returned.
+                        return error_dict("Bad roi coordinate."), 400
+                    if (i == 0 or i == 2) and coord > image.size[0]:
+                        return (
+                            error_dict(
+                                "The roi coordinate cannot be larger than the image size."
+                            ),
+                            400,
+                        )
+                    elif (i == 1 or i == 3) and coord > image.size[1]:
+                        return (
+                            error_dict(
+                                "The roi coordinate cannot be larger than the image size."
+                            ),
+                            400,
+                        )
+                    roi_box[i] = int(coord)
+                elif "%" in coord:
+                    coord = coord.replace("%", "")
+                    try:
+                        coord = int(coord)
+                    except ValueError:
+                        return (
+                            error_dict(
+                                "Error converting roi coordinate string to number!"
+                            ),
+                            400,
+                        )
+                    if coord > 100:
+                        return (
+                            error_dict("The coordinate cannot be more than 100%"),
+                            400,
+                        )
+                    elif coord < 0:
+                        return error_dict("Coordinate cannot be less than 0%"), 400
+                    # Even indices are x coordinates (width), odd ones are y (height).
+                    if i == 0 or i == 2:
+                        coord = int(image.size[0] * coord / 100)
+                    elif i == 1 or i == 3:
+                        coord = int(image.size[1] * coord / 100)
+                    roi_box[i] = coord
+                else:
+                    return error_dict("Something wrong with roi coordinates!"), 400
+
+    new_image = image.copy()
+    new_image = new_image.crop(roi_box)
+    width, height = new_image.size
+    if width < 2 or height < 2:
+        return error_dict("Image is too small. Minimum size 2x2"), 400
+    # Run background removal on the cropped region only.
+    new_image = interface([new_image])[0]
+
+    scaled = False
+    if "scale" in params.keys() and params["scale"] != 100:
+        value = params["scale"]
+        new_image.thumbnail(
+            (int(image.size[0] * value / 100), int(image.size[1] * value / 100)),
+            resample=3,
+        )
+        scaled = True
+    if "crop" in params.keys():
+        value = params["crop"]
+        if value:
+            # Crop to the bounding box of the non-empty area, then add the margin.
+            new_image = new_image.crop(new_image.getbbox())
+            if "crop_margin" in params.keys():
+                crop_margin = params["crop_margin"]
+                if "px" in crop_margin:
+                    crop_margin = crop_margin.replace("px", "")
+                    crop_margin = abs(int(crop_margin))
+                    if crop_margin > 500:
+                        return (
+                            error_dict(
+                                "The crop_margin cannot be larger than the original image size."
+                            ),
+                            400,
+                        )
+                    new_image = add_margin(
+                        new_image,
+                        crop_margin,
+                        crop_margin,
+                        crop_margin,
+                        crop_margin,
+                        (0, 0, 0, 0),
+                    )
+                elif "%" in crop_margin:
+                    crop_margin = crop_margin.replace("%", "")
+                    crop_margin = int(crop_margin)
+                    new_image = add_margin(
+                        new_image,
+                        int(new_image.size[1] * crop_margin / 100),
+                        int(new_image.size[0] * crop_margin / 100),
+                        int(new_image.size[1] * crop_margin / 100),
+                        int(new_image.size[0] * crop_margin / 100),
+                        (0, 0, 0, 0),
+                    )
+        else:
+            if "position" in params.keys() and scaled is False:
+                value = params["position"]
+                if len(value) == 2:
+                    new_image = transparency_paste(
+                        Image.new("RGBA", image.size),
+                        new_image,
+                        (
+                            int(image.size[0] * value[0] / 100),
+                            int(image.size[1] * value[1] / 100),
+                        ),
+                    )
+                else:
+                    new_image = transparency_paste(
+                        Image.new("RGBA", image.size), new_image, roi_box
+                    )
+            elif scaled is False:
+                # Put the processed region back at its place on a full-size canvas.
+                new_image = transparency_paste(
+                    Image.new("RGBA", image.size), new_image, roi_box
+                )
+
+    if "channels" in params.keys():
+        value = params["channels"]
+        if value == "alpha":
+            new_image = extract_alpha_channel(new_image)
+        else:
+            # Background priority: bg_color, then bg_image_url, then uploaded bg.
+            bg_changed = False
+            if "bg_color" in params.keys():
+                value = params["bg_color"]
+                if len(value) > 0:
+                    color = ImageColor.getcolor(value, "RGB")
+                    bg = Image.new("RGBA", new_image.size, color)
+                    bg = transparency_paste(bg, new_image, (0, 0))
+                    new_image = bg.copy()
+                    bg_changed = True
+            if "bg_image_url" in params.keys() and bg_changed is False:
+                value = params["bg_image_url"]
+                if len(value) > 0:
+                    # NOTE(review): the URL comes from the request and is fetched
+                    # without a timeout or host restriction - consider hardening.
+                    try:
+                        bg = Image.open(io.BytesIO(requests.get(value).content))
+                    except Exception:
+                        return error_dict("Error download background image!"), 400
+                    bg = bg.resize(new_image.size)
+                    bg = bg.convert("RGBA")
+                    bg = transparency_paste(bg, new_image, (0, 0))
+                    new_image = bg.copy()
+                    bg_changed = True
+            if not is_json_or_www_encoded:
+                if bg and bg_changed is False:
+                    bg = bg.resize(new_image.size)
+                    bg = bg.convert("RGBA")
+                    bg = transparency_paste(bg, new_image, (0, 0))
+                    new_image = bg.copy()
+    if "format" in params.keys():
+        value = params["format"]
+        if value == "jpg":
+            new_image = new_image.convert("RGB")
+            img_io = io.BytesIO()
+            new_image.save(img_io, "JPEG", quality=100)
+            img_io.seek(0)
+            return {"type": "jpg", "data": [img_io, new_image.size]}
+        elif value == "zip":
+            # The archive holds the input image as color.jpg and the mask as alpha.png.
+            mask = extract_alpha_channel(new_image)
+            mask_buff = io.BytesIO()
+            mask.save(mask_buff, "PNG")
+            mask_buff.seek(0)
+            image_buff = io.BytesIO()
+            image.save(image_buff, "JPEG")
+            image_buff.seek(0)
+            fileobj = io.BytesIO()
+            with zipfile.ZipFile(fileobj, "w") as zip_file:
+                zip_info = zipfile.ZipInfo(filename="color.jpg")
+                zip_info.date_time = time.localtime(time.time())[:6]
+                zip_info.compress_type = zipfile.ZIP_DEFLATED
+                zip_file.writestr(zip_info, image_buff.getvalue())
+                zip_info = zipfile.ZipInfo(filename="alpha.png")
+                zip_info.date_time = time.localtime(time.time())[:6]
+                zip_info.compress_type = zipfile.ZIP_DEFLATED
+                zip_file.writestr(zip_info, mask_buff.getvalue())
+            fileobj.seek(0)
+            return {"type": "zip", "data": [fileobj.read(), new_image.size]}
+        else:
+            buff = io.BytesIO()
+            new_image.save(buff, "PNG")
+            buff.seek(0)
+            return {"type": "png", "data": [buff, new_image.size]}
+    # Reached only when no "format" parameter was supplied.
+    return (
+        error_dict(
+            "Something wrong with request or http api. Please, open new issue on Github! This is error in "
+            "code."
+        ),
+        400,
+    )
+
+
+
+
+
+
+
+

Functions

+
+
+def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_encoded=False) +
+
+

Handles a request to the removebg api method

+

Args

+
+
interface
+
CarveKit interface
+
bg
+
background pil image
+
is_json_or_www_encoded
+
is "json" or "x-www-form-urlencoded" content-type
+
image
+
foreground pil image
+
params
+
parameters
+
+
+ +Expand source code + +
def process_remove_bg(
+    interface: Interface, params, image, bg, is_json_or_www_encoded=False
+):
+    """
+    Handles a request to the removebg api method
+
+    Args:
+        interface: CarveKit interface
+        bg: background pil image
+        is_json_or_www_encoded: is "json" or "x-www-form-urlencoded" content-type
+        image: foreground pil image
+        params: parameters
+    """
+    h, w = image.size
+    if h < 2 or w < 2:
+        return error_dict("Image is too small. Minimum size 2x2"), 400
+
+    if "size" in params.keys():
+        value = params["size"]
+        if value == "preview" or value == "small" or value == "regular":
+            image.thumbnail((625, 400), resample=3)  # 0.25 mp
+        elif value == "medium":
+            image.thumbnail((1504, 1000), resample=3)  # 1.5 mp
+        elif value == "hd":
+            image.thumbnail((2000, 2000), resample=3)  # 2.5 mp
+        else:
+            image.thumbnail((6250, 4000), resample=3)  # 25 mp
+
+    roi_box = [0, 0, image.size[0], image.size[1]]
+    if "type" in params.keys():
+        value = params["type"]
+        pass
+
+    if "roi" in params.keys():
+        value = params["roi"].split(" ")
+        if len(value) == 4:
+            for i, coord in enumerate(value):
+                if "px" in coord:
+                    coord = coord.replace("px", "")
+                    try:
+                        coord = int(coord)
+                    except BaseException:
+                        return (
+                            error_dict(
+                                "Error converting roi coordinate string to number!"
+                            ),
+                            400,
+                        )
+                    if coord < 0:
+                        error_dict("Bad roi coordinate."), 400
+                    if (i == 0 or i == 2) and coord > image.size[0]:
+                        return (
+                            error_dict(
+                                "The roi coordinate cannot be larger than the image size."
+                            ),
+                            400,
+                        )
+                    elif (i == 1 or i == 3) and coord > image.size[1]:
+                        return (
+                            error_dict(
+                                "The roi coordinate cannot be larger than the image size."
+                            ),
+                            400,
+                        )
+                    roi_box[i] = int(coord)
+                elif "%" in coord:
+                    coord = coord.replace("%", "")
+                    try:
+                        coord = int(coord)
+                    except BaseException:
+                        return (
+                            error_dict(
+                                "Error converting roi coordinate string to number!"
+                            ),
+                            400,
+                        )
+                    if coord > 100:
+                        return (
+                            error_dict("The coordinate cannot be more than 100%"),
+                            400,
+                        )
+                    elif coord < 0:
+                        return error_dict("Coordinate cannot be less than 0%"), 400
+                    if i == 0 or i == 2:
+                        coord = int(image.size[0] * coord / 100)
+                    elif i == 1 or i == 3:
+                        coord = int(image.size[1] * coord / 100)
+                    roi_box[i] = coord
+                else:
+                    return error_dict("Something wrong with roi coordinates!"), 400
+
+    new_image = image.copy()
+    new_image = new_image.crop(roi_box)
+    h, w = new_image.size
+    if h < 2 or w < 2:
+        return error_dict("Image is too small. Minimum size 2x2"), 400
+    new_image = interface([new_image])[0]
+
+    scaled = False
+    if "scale" in params.keys() and params["scale"] != 100:
+        value = params["scale"]
+        new_image.thumbnail(
+            (int(image.size[0] * value / 100), int(image.size[1] * value / 100)),
+            resample=3,
+        )
+        scaled = True
+    if "crop" in params.keys():
+        value = params["crop"]
+        if value:
+            new_image = new_image.crop(new_image.getbbox())
+            if "crop_margin" in params.keys():
+                crop_margin = params["crop_margin"]
+                if "px" in crop_margin:
+                    crop_margin = crop_margin.replace("px", "")
+                    crop_margin = abs(int(crop_margin))
+                    if crop_margin > 500:
+                        return (
+                            error_dict(
+                                "The crop_margin cannot be larger than the original image size."
+                            ),
+                            400,
+                        )
+                    new_image = add_margin(
+                        new_image,
+                        crop_margin,
+                        crop_margin,
+                        crop_margin,
+                        crop_margin,
+                        (0, 0, 0, 0),
+                    )
+                elif "%" in crop_margin:
+                    crop_margin = crop_margin.replace("%", "")
+                    crop_margin = int(crop_margin)
+                    new_image = add_margin(
+                        new_image,
+                        int(new_image.size[1] * crop_margin / 100),
+                        int(new_image.size[0] * crop_margin / 100),
+                        int(new_image.size[1] * crop_margin / 100),
+                        int(new_image.size[0] * crop_margin / 100),
+                        (0, 0, 0, 0),
+                    )
+        else:
+            if "position" in params.keys() and scaled is False:
+                value = params["position"]
+                if len(value) == 2:
+                    new_image = transparency_paste(
+                        Image.new("RGBA", image.size),
+                        new_image,
+                        (
+                            int(image.size[0] * value[0] / 100),
+                            int(image.size[1] * value[1] / 100),
+                        ),
+                    )
+                else:
+                    new_image = transparency_paste(
+                        Image.new("RGBA", image.size), new_image, roi_box
+                    )
+            elif scaled is False:
+                new_image = transparency_paste(
+                    Image.new("RGBA", image.size), new_image, roi_box
+                )
+
+    if "channels" in params.keys():
+        value = params["channels"]
+        if value == "alpha":
+            new_image = extract_alpha_channel(new_image)
+        else:
+            bg_changed = False
+            if "bg_color" in params.keys():
+                value = params["bg_color"]
+                if len(value) > 0:
+                    color = ImageColor.getcolor(value, "RGB")
+                    bg = Image.new("RGBA", new_image.size, color)
+                    bg = transparency_paste(bg, new_image, (0, 0))
+                    new_image = bg.copy()
+                    bg_changed = True
+            if "bg_image_url" in params.keys() and bg_changed is False:
+                value = params["bg_image_url"]
+                if len(value) > 0:
+                    try:
+                        bg = Image.open(io.BytesIO(requests.get(value).content))
+                    except BaseException:
+                        return error_dict("Error download background image!"), 400
+                    bg = bg.resize(new_image.size)
+                    bg = bg.convert("RGBA")
+                    bg = transparency_paste(bg, new_image, (0, 0))
+                    new_image = bg.copy()
+                    bg_changed = True
+            if not is_json_or_www_encoded:
+                if bg and bg_changed is False:
+                    bg = bg.resize(new_image.size)
+                    bg = bg.convert("RGBA")
+                    bg = transparency_paste(bg, new_image, (0, 0))
+                    new_image = bg.copy()
+    if "format" in params.keys():
+        value = params["format"]
+        if value == "jpg":
+            new_image = new_image.convert("RGB")
+            img_io = io.BytesIO()
+            new_image.save(img_io, "JPEG", quality=100)
+            img_io.seek(0)
+            return {"type": "jpg", "data": [img_io, new_image.size]}
+        elif value == "zip":
+            mask = extract_alpha_channel(new_image)
+            mask_buff = io.BytesIO()
+            mask.save(mask_buff, "PNG")
+            mask_buff.seek(0)
+            image_buff = io.BytesIO()
+            image.save(image_buff, "JPEG")
+            image_buff.seek(0)
+            fileobj = io.BytesIO()
+            with zipfile.ZipFile(fileobj, "w") as zip_file:
+                zip_info = zipfile.ZipInfo(filename="color.jpg")
+                zip_info.date_time = time.localtime(time.time())[:6]
+                zip_info.compress_type = zipfile.ZIP_DEFLATED
+                zip_file.writestr(zip_info, image_buff.getvalue())
+                zip_info = zipfile.ZipInfo(filename="alpha.png")
+                zip_info.date_time = time.localtime(time.time())[:6]
+                zip_info.compress_type = zipfile.ZIP_DEFLATED
+                zip_file.writestr(zip_info, mask_buff.getvalue())
+            fileobj.seek(0)
+            return {"type": "zip", "data": [fileobj.read(), new_image.size]}
+        else:
+            buff = io.BytesIO()
+            new_image.save(buff, "PNG")
+            buff.seek(0)
+            return {"type": "png", "data": [buff, new_image.size]}
+    return (
+        error_dict(
+            "Something wrong with request or http api. Please, open new issue on Github! This is error in "
+            "code."
+        ),
+        400,
+    )
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/responses/api.html b/docs/api/carvekit/web/responses/api.html new file mode 100644 index 0000000..f1507e0 --- /dev/null +++ b/docs/api/carvekit/web/responses/api.html @@ -0,0 +1,95 @@ + + + + + + +carvekit.web.responses.api API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.responses.api

+
+
+
+ +Expand source code + +
def error_dict(error_text: str):
+    """
+    Generates a dictionary containing $error_text error
+    :param error_text: Error text
+    :return: error dictionary
+    """
+    resp = {"errors": [{"title": error_text}]}
+    return resp
+
+
+
+
+
+
+
+

Functions

+
+
+def error_dict(error_text:Β str) +
+
+

Generates a dictionary containing $error_text error +:param error_text: Error text +:return: error dictionary

+
+ +Expand source code + +
def error_dict(error_text: str):
+    """
+    Generates a dictionary containing $error_text error
+    :param error_text: Error text
+    :return: error dictionary
+    """
+    resp = {"errors": [{"title": error_text}]}
+    return resp
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/responses/index.html b/docs/api/carvekit/web/responses/index.html new file mode 100644 index 0000000..021527c --- /dev/null +++ b/docs/api/carvekit/web/responses/index.html @@ -0,0 +1,65 @@ + + + + + + +carvekit.web.responses API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.responses

+
+
+
+
+

Sub-modules

+
+
carvekit.web.responses.api
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/routers/api_router.html b/docs/api/carvekit/web/routers/api_router.html new file mode 100644 index 0000000..d99ca0e --- /dev/null +++ b/docs/api/carvekit/web/routers/api_router.html @@ -0,0 +1,517 @@ + + + + + + +carvekit.web.routers.api_router API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.routers.api_router

+
+
+
+ +Expand source code + +
import base64
+import http
+import io
+import time
+from json import JSONDecodeError
+from typing import Optional
+
+import requests
+from PIL import Image
+from fastapi import Header, Depends, Form, File, Request, APIRouter, UploadFile
+from fastapi.openapi.models import Response
+from pydantic import ValidationError
+from starlette.responses import JSONResponse
+
+from carvekit.web.deps import config, ml_processor
+from carvekit.web.handlers.response import handle_response, Authenticate
+from carvekit.web.responses.api import error_dict
+from carvekit.web.schemas.request import Parameters
+from carvekit.web.utils.net_utils import is_loopback
+
+api_router = APIRouter(prefix="", tags=["api"])
+
+
+# noinspection PyBroadException
+@api_router.post("/removebg")
+async def removebg(
+    request: Request,
+    image_file: Optional[bytes] = File(None),
+    auth: bool = Depends(Authenticate),
+    content_type: str = Header(""),
+    image_file_b64: Optional[str] = Form(None),
+    image_url: Optional[str] = Form(None),
+    bg_image_file: Optional[bytes] = File(None),
+    size: Optional[str] = Form("full"),
+    type: Optional[str] = Form("auto"),
+    format: Optional[str] = Form("auto"),
+    roi: str = Form("0% 0% 100% 100%"),
+    crop: bool = Form(False),
+    crop_margin: Optional[str] = Form("0px"),
+    scale: Optional[str] = Form("original"),
+    position: Optional[str] = Form("original"),
+    channels: Optional[str] = Form("rgba"),
+    add_shadow: bool = Form(False),  # Not supported at the moment
+    semitransparency: bool = Form(False),  # Not supported at the moment
+    bg_color: Optional[str] = Form(""),
+):
+    if auth is False:
+        return JSONResponse(content=error_dict("Missing API Key"), status_code=403)
+    if (
+        content_type not in ["application/x-www-form-urlencoded", "application/json"]
+        and "multipart/form-data" not in content_type
+    ):
+        return JSONResponse(
+            content=error_dict("Invalid request content type"), status_code=400
+        )
+
+    if image_url:
+        if not (
+            image_url.startswith("http://") or image_url.startswith("https://")
+        ) or is_loopback(image_url):
+            print(
+                f"Possible ssrf attempt to /api/removebg endpoint with image url: {image_url}"
+            )
+            return JSONResponse(
+                content=error_dict("Invalid image url."), status_code=400
+            )  # possible ssrf attempt
+
+    image = None
+    bg = None
+    parameters = None
+    if (
+        content_type == "application/x-www-form-urlencoded"
+        or "multipart/form-data" in content_type
+    ):
+        if image_file_b64 is None and image_url is None and image_file is None:
+            return JSONResponse(content=error_dict("File not found"), status_code=400)
+
+        if image_file_b64:
+            if len(image_file_b64) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            try:
+                image = Image.open(io.BytesIO(base64.b64decode(image_file_b64)))
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error decode image!"), status_code=400
+                )
+        elif image_url:
+            try:
+                image = Image.open(io.BytesIO(requests.get(image_url).content))
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error download image!"), status_code=400
+                )
+        elif image_file:
+            if len(image_file) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            image = Image.open(io.BytesIO(image_file))
+
+        if bg_image_file:
+            if len(bg_image_file) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            bg = Image.open(io.BytesIO(bg_image_file))
+        try:
+            parameters = Parameters(
+                image_file_b64=image_file_b64,
+                image_url=image_url,
+                size=size,
+                type=type,
+                format=format,
+                roi=roi,
+                crop=crop,
+                crop_margin=crop_margin,
+                scale=scale,
+                position=position,
+                channels=channels,
+                add_shadow=add_shadow,
+                semitransparency=semitransparency,
+                bg_color=bg_color,
+            )
+        except ValidationError as e:
+            return JSONResponse(
+                content=e.json(), status_code=400, media_type="application/json"
+            )
+
+    else:
+        payload = None
+        try:
+            payload = await request.json()
+        except JSONDecodeError:
+            return JSONResponse(content=error_dict("Empty json"), status_code=400)
+        try:
+            parameters = Parameters(**payload)
+        except ValidationError as e:
+            return Response(
+                content=e.json(), status_code=400, media_type="application/json"
+            )
+        if parameters.image_file_b64 is None and parameters.image_url is None:
+            return JSONResponse(content=error_dict("File not found"), status_code=400)
+
+        if parameters.image_file_b64:
+            if len(parameters.image_file_b64) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            try:
+                image = Image.open(
+                    io.BytesIO(base64.b64decode(parameters.image_file_b64))
+                )
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error decode image!"), status_code=400
+                )
+        elif parameters.image_url:
+            if not (
+                parameters.image_url.startswith("http://")
+                or parameters.image_url.startswith("https://")
+            ) or is_loopback(parameters.image_url):
+                print(
+                    f"Possible ssrf attempt to /api/removebg endpoint with image url: {parameters.image_url}"
+                )
+                return JSONResponse(
+                    content=error_dict("Invalid image url."), status_code=400
+                )  # possible ssrf attempt
+            try:
+                image = Image.open(
+                    io.BytesIO(requests.get(parameters.image_url).content)
+                )
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error download image!"), status_code=400
+                )
+        if image is None:
+            return JSONResponse(
+                content=error_dict("Error download image!"), status_code=400
+            )
+
+    job_id = ml_processor.job_create([parameters.dict(), image, bg, False])
+
+    while ml_processor.job_status(job_id) != "finished":
+        if ml_processor.job_status(job_id) == "not_found":
+            return JSONResponse(
+                content=error_dict("Job ID not found!"), status_code=500
+            )
+        time.sleep(5)
+
+    result = ml_processor.job_result(job_id)
+    return handle_response(result, image)
+
+
+@api_router.get("/account")
+def account():
+    """
+    Stub for compatibility with remove.bg api libraries
+    """
+    return JSONResponse(
+        content={
+            "data": {
+                "attributes": {
+                    "credits": {
+                        "total": 99999,
+                        "subscription": 99999,
+                        "payg": 99999,
+                        "enterprise": 99999,
+                    },
+                    "api": {"free_calls": 99999, "sizes": "all"},
+                }
+            }
+        },
+        status_code=200,
+    )
+
+
+@api_router.get("/admin/config")
+def status(auth: str = Depends(Authenticate)):
+    """
+    Returns the current server config.
+    """
+    if not auth or auth != "admin":
+        return JSONResponse(
+            content=error_dict("Authentication failed"), status_code=403
+        )
+    resp = JSONResponse(content=config.json(), status_code=200)
+    resp.headers["X-Credits-Charged"] = "0"
+    return resp
+
+
+
+
+
+
+
+

Functions

+
+
+def account() +
+
+

Stub for compatibility with remove.bg api libraries

+
+ +Expand source code + +
@api_router.get("/account")
+def account():
+    """
+    Stub for compatibility with remove.bg api libraries
+    """
+    return JSONResponse(
+        content={
+            "data": {
+                "attributes": {
+                    "credits": {
+                        "total": 99999,
+                        "subscription": 99999,
+                        "payg": 99999,
+                        "enterprise": 99999,
+                    },
+                    "api": {"free_calls": 99999, "sizes": "all"},
+                }
+            }
+        },
+        status_code=200,
+    )
+
+
+
+async def removebg(request:Β starlette.requests.Request, image_file:Β Optional[bytes]Β =Β File(None), auth:Β boolΒ =Β Depends(Authenticate), content_type:Β strΒ =Β Header(), image_file_b64:Β Optional[str]Β =Β Form(None), image_url:Β Optional[str]Β =Β Form(None), bg_image_file:Β Optional[bytes]Β =Β File(None), size:Β Optional[str]Β =Β Form(full), type:Β Optional[str]Β =Β Form(auto), format:Β Optional[str]Β =Β Form(auto), roi:Β strΒ =Β Form(0% 0% 100% 100%), crop:Β boolΒ =Β Form(False), crop_margin:Β Optional[str]Β =Β Form(0px), scale:Β Optional[str]Β =Β Form(original), position:Β Optional[str]Β =Β Form(original), channels:Β Optional[str]Β =Β Form(rgba), add_shadow:Β boolΒ =Β Form(False), semitransparency:Β boolΒ =Β Form(False), bg_color:Β Optional[str]Β =Β Form()) +
+
+
+
+ +Expand source code + +
@api_router.post("/removebg")
+async def removebg(
+    request: Request,
+    image_file: Optional[bytes] = File(None),
+    auth: bool = Depends(Authenticate),
+    content_type: str = Header(""),
+    image_file_b64: Optional[str] = Form(None),
+    image_url: Optional[str] = Form(None),
+    bg_image_file: Optional[bytes] = File(None),
+    size: Optional[str] = Form("full"),
+    type: Optional[str] = Form("auto"),
+    format: Optional[str] = Form("auto"),
+    roi: str = Form("0% 0% 100% 100%"),
+    crop: bool = Form(False),
+    crop_margin: Optional[str] = Form("0px"),
+    scale: Optional[str] = Form("original"),
+    position: Optional[str] = Form("original"),
+    channels: Optional[str] = Form("rgba"),
+    add_shadow: bool = Form(False),  # Not supported at the moment
+    semitransparency: bool = Form(False),  # Not supported at the moment
+    bg_color: Optional[str] = Form(""),
+):
+    if auth is False:
+        return JSONResponse(content=error_dict("Missing API Key"), status_code=403)
+    if (
+        content_type not in ["application/x-www-form-urlencoded", "application/json"]
+        and "multipart/form-data" not in content_type
+    ):
+        return JSONResponse(
+            content=error_dict("Invalid request content type"), status_code=400
+        )
+
+    if image_url:
+        if not (
+            image_url.startswith("http://") or image_url.startswith("https://")
+        ) or is_loopback(image_url):
+            print(
+                f"Possible ssrf attempt to /api/removebg endpoint with image url: {image_url}"
+            )
+            return JSONResponse(
+                content=error_dict("Invalid image url."), status_code=400
+            )  # possible ssrf attempt
+
+    image = None
+    bg = None
+    parameters = None
+    if (
+        content_type == "application/x-www-form-urlencoded"
+        or "multipart/form-data" in content_type
+    ):
+        if image_file_b64 is None and image_url is None and image_file is None:
+            return JSONResponse(content=error_dict("File not found"), status_code=400)
+
+        if image_file_b64:
+            if len(image_file_b64) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            try:
+                image = Image.open(io.BytesIO(base64.b64decode(image_file_b64)))
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error decode image!"), status_code=400
+                )
+        elif image_url:
+            try:
+                image = Image.open(io.BytesIO(requests.get(image_url).content))
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error download image!"), status_code=400
+                )
+        elif image_file:
+            if len(image_file) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            image = Image.open(io.BytesIO(image_file))
+
+        if bg_image_file:
+            if len(bg_image_file) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            bg = Image.open(io.BytesIO(bg_image_file))
+        try:
+            parameters = Parameters(
+                image_file_b64=image_file_b64,
+                image_url=image_url,
+                size=size,
+                type=type,
+                format=format,
+                roi=roi,
+                crop=crop,
+                crop_margin=crop_margin,
+                scale=scale,
+                position=position,
+                channels=channels,
+                add_shadow=add_shadow,
+                semitransparency=semitransparency,
+                bg_color=bg_color,
+            )
+        except ValidationError as e:
+            return JSONResponse(
+                content=e.json(), status_code=400, media_type="application/json"
+            )
+
+    else:
+        payload = None
+        try:
+            payload = await request.json()
+        except JSONDecodeError:
+            return JSONResponse(content=error_dict("Empty json"), status_code=400)
+        try:
+            parameters = Parameters(**payload)
+        except ValidationError as e:
+            return Response(
+                content=e.json(), status_code=400, media_type="application/json"
+            )
+        if parameters.image_file_b64 is None and parameters.image_url is None:
+            return JSONResponse(content=error_dict("File not found"), status_code=400)
+
+        if parameters.image_file_b64:
+            if len(parameters.image_file_b64) == 0:
+                return JSONResponse(content=error_dict("Empty image"), status_code=400)
+            try:
+                image = Image.open(
+                    io.BytesIO(base64.b64decode(parameters.image_file_b64))
+                )
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error decode image!"), status_code=400
+                )
+        elif parameters.image_url:
+            if not (
+                parameters.image_url.startswith("http://")
+                or parameters.image_url.startswith("https://")
+            ) or is_loopback(parameters.image_url):
+                print(
+                    f"Possible ssrf attempt to /api/removebg endpoint with image url: {parameters.image_url}"
+                )
+                return JSONResponse(
+                    content=error_dict("Invalid image url."), status_code=400
+                )  # possible ssrf attempt
+            try:
+                image = Image.open(
+                    io.BytesIO(requests.get(parameters.image_url).content)
+                )
+            except BaseException:
+                return JSONResponse(
+                    content=error_dict("Error download image!"), status_code=400
+                )
+        if image is None:
+            return JSONResponse(
+                content=error_dict("Error download image!"), status_code=400
+            )
+
+    job_id = ml_processor.job_create([parameters.dict(), image, bg, False])
+
+    while ml_processor.job_status(job_id) != "finished":
+        if ml_processor.job_status(job_id) == "not_found":
+            return JSONResponse(
+                content=error_dict("Job ID not found!"), status_code=500
+            )
+        time.sleep(5)
+
+    result = ml_processor.job_result(job_id)
+    return handle_response(result, image)
+
+
+
+def status(auth:Β strΒ =Β Depends(Authenticate)) +
+
+

Returns the current server config.

+
+ +Expand source code + +
@api_router.get("/admin/config")
+def status(auth: str = Depends(Authenticate)):
+    """
+    Returns the current server config.
+    """
+    if not auth or auth != "admin":
+        return JSONResponse(
+            content=error_dict("Authentication failed"), status_code=403
+        )
+    resp = JSONResponse(content=config.json(), status_code=200)
+    resp.headers["X-Credits-Charged"] = "0"
+    return resp
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/routers/index.html b/docs/api/carvekit/web/routers/index.html new file mode 100644 index 0000000..1da91f5 --- /dev/null +++ b/docs/api/carvekit/web/routers/index.html @@ -0,0 +1,65 @@ + + + + + + +carvekit.web.routers API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.routers

+
+
+
+
+

Sub-modules

+
+
carvekit.web.routers.api_router
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/schemas/config.html b/docs/api/carvekit/web/schemas/config.html new file mode 100644 index 0000000..3e5285e --- /dev/null +++ b/docs/api/carvekit/web/schemas/config.html @@ -0,0 +1,550 @@ + + + + + + +carvekit.web.schemas.config API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.schemas.config

+
+
+
+ +Expand source code + +
import secrets
+from typing import List
+from typing_extensions import Literal
+
+import torch.cuda
+from pydantic import BaseModel, validator
+
+
+class AuthConfig(BaseModel):
+    """Config for web api token authentication"""
+
+    auth: bool = True
+    """Enables Token Authentication for API"""
+    admin_token: str = secrets.token_hex(32)
+    """Admin Token"""
+    allowed_tokens: List[str] = [secrets.token_hex(32)]
+    """All allowed tokens"""
+
+
+class MLConfig(BaseModel):
+    """Config for ml part of framework"""
+
+    segmentation_network: Literal[
+        "u2net", "deeplabv3", "basnet", "tracer_b7"
+    ] = "tracer_b7"
+    """Segmentation Network"""
+    preprocessing_method: Literal["none", "stub", "autoscene", "auto"] = "autoscene"
+    """Pre-processing Method"""
+    postprocessing_method: Literal["fba", "cascade_fba", "none"] = "cascade_fba"
+    """Post-Processing Network"""
+    device: str = "cpu"
+    """Processing device"""
+    batch_size_pre: int = 5
+    """Batch size for preprocessing method"""
+    batch_size_seg: int = 5
+    """Batch size for segmentation network"""
+    batch_size_matting: int = 1
+    """Batch size for matting network"""
+    batch_size_refine: int = 1
+    """Batch size for refine network"""
+    seg_mask_size: int = 640
+    """The size of the input image for the segmentation neural network."""
+    matting_mask_size: int = 2048
+    """The size of the input image for the matting neural network."""
+    refine_mask_size: int = 900
+    """The size of the input image for the refine neural network."""
+    fp16: bool = False
+    """Use half precision for inference"""
+    trimap_dilation: int = 30
+    """Dilation size for trimap"""
+    trimap_erosion: int = 5
+    """Erosion levels for trimap"""
+    trimap_prob_threshold: int = 231
+    """Probability threshold for trimap generation"""
+
+    @validator("seg_mask_size")
+    def seg_mask_size_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect seg_mask_size!")
+
+    @validator("matting_mask_size")
+    def matting_mask_size_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect matting_mask_size!")
+
+    @validator("batch_size_seg")
+    def batch_size_seg_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect batch size!")
+
+    @validator("batch_size_matting")
+    def batch_size_matting_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect batch size!")
+
+    @validator("device")
+    def device_validator(cls, value):
+        if torch.cuda.is_available() is False and "cuda" in value:
+            raise ValueError(
+                "GPU is not available, but specified as processing device!"
+            )
+        if "cuda" not in value and "cpu" != value:
+            raise ValueError("Unknown processing device! It should be cpu or cuda!")
+        return value
+
+
+class WebAPIConfig(BaseModel):
+    """FastAPI app config"""
+
+    port: int = 5000
+    """Web API port"""
+    host: str = "0.0.0.0"
+    """Web API host"""
+    ml: MLConfig = MLConfig()
+    """Config for ml part of framework"""
+    auth: AuthConfig = AuthConfig()
+    """Config for web api token authentication """
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class AuthConfig +(**data: Any) +
+
+

Config for web api token authentication

+

Create a new model by parsing and validating input data from keyword arguments.

+

Raises ValidationError if the input data cannot be parsed to form a valid model.

+
+ +Expand source code + +
class AuthConfig(BaseModel):
+    """Config for web api token authentication"""
+
+    auth: bool = True
+    """Enables Token Authentication for API"""
+    admin_token: str = secrets.token_hex(32)
+    """Admin Token"""
+    allowed_tokens: List[str] = [secrets.token_hex(32)]
+    """All allowed tokens"""
+
+

Ancestors

+
    +
  • pydantic.main.BaseModel
  • +
  • pydantic.utils.Representation
  • +
+

Class variables

+
+
var admin_token : str
+
+

Admin Token

+
+
var allowed_tokens : List[str]
+
+

All allowed tokens

+
+
var auth : bool
+
+

Enables Token Authentication for API

+
+
+
+
+class MLConfig +(**data: Any) +
+
+

Config for ml part of framework

+

Create a new model by parsing and validating input data from keyword arguments.

+

Raises ValidationError if the input data cannot be parsed to form a valid model.

+
+ +Expand source code + +
class MLConfig(BaseModel):
+    """Config for ml part of framework"""
+
+    segmentation_network: Literal[
+        "u2net", "deeplabv3", "basnet", "tracer_b7"
+    ] = "tracer_b7"
+    """Segmentation Network"""
+    preprocessing_method: Literal["none", "stub", "autoscene", "auto"] = "autoscene"
+    """Pre-processing Method"""
+    postprocessing_method: Literal["fba", "cascade_fba", "none"] = "cascade_fba"
+    """Post-Processing Network"""
+    device: str = "cpu"
+    """Processing device"""
+    batch_size_pre: int = 5
+    """Batch size for preprocessing method"""
+    batch_size_seg: int = 5
+    """Batch size for segmentation network"""
+    batch_size_matting: int = 1
+    """Batch size for matting network"""
+    batch_size_refine: int = 1
+    """Batch size for refine network"""
+    seg_mask_size: int = 640
+    """The size of the input image for the segmentation neural network."""
+    matting_mask_size: int = 2048
+    """The size of the input image for the matting neural network."""
+    refine_mask_size: int = 900
+    """The size of the input image for the refine neural network."""
+    fp16: bool = False
+    """Use half precision for inference"""
+    trimap_dilation: int = 30
+    """Dilation size for trimap"""
+    trimap_erosion: int = 5
+    """Erosion levels for trimap"""
+    trimap_prob_threshold: int = 231
+    """Probability threshold for trimap generation"""
+
+    @validator("seg_mask_size")
+    def seg_mask_size_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect seg_mask_size!")
+
+    @validator("matting_mask_size")
+    def matting_mask_size_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect matting_mask_size!")
+
+    @validator("batch_size_seg")
+    def batch_size_seg_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect batch size!")
+
+    @validator("batch_size_matting")
+    def batch_size_matting_validator(cls, value: int, values):
+        if value > 0:
+            return value
+        else:
+            raise ValueError("Incorrect batch size!")
+
+    @validator("device")
+    def device_validator(cls, value):
+        if torch.cuda.is_available() is False and "cuda" in value:
+            raise ValueError(
+                "GPU is not available, but specified as processing device!"
+            )
+        if "cuda" not in value and "cpu" != value:
+            raise ValueError("Unknown processing device! It should be cpu or cuda!")
+        return value
+
+

Ancestors

+
    +
  • pydantic.main.BaseModel
  • +
  • pydantic.utils.Representation
  • +
+

Class variables

+
+
var batch_size_matting : int
+
+

Batch size for matting network

+
+
var batch_size_pre : int
+
+

Batch size for preprocessing method

+
+
var batch_size_refine : int
+
+

Batch size for refine network

+
+
var batch_size_seg : int
+
+

Batch size for segmentation network

+
+
var device : str
+
+

Processing device

+
+
var fp16 : bool
+
+

Use half precision for inference

+
+
var matting_mask_size : int
+
+

The size of the input image for the matting neural network.

+
+
var postprocessing_method : Literal['fba', 'cascade_fba', 'none']
+
+

Post-Processing Network

+
+
var preprocessing_method : Literal['none', 'stub', 'autoscene', 'auto']
+
+

Pre-processing Method

+
+
var refine_mask_size : int
+
+

The size of the input image for the refine neural network.

+
+
var seg_mask_size : int
+
+

The size of the input image for the segmentation neural network.

+
+
var segmentation_network : Literal['u2net', 'deeplabv3', 'basnet', 'tracer_b7']
+
+

Segmentation Network

+
+
var trimap_dilation : int
+
+

Dilation size for trimap

+
+
var trimap_erosion : int
+
+

Erosion levels for trimap

+
+
var trimap_prob_threshold : int
+
+

Probability threshold for trimap generation

+
+
+

Static methods

+
+
+def batch_size_matting_validator(value: int, values) +
+
+
+
+ +Expand source code + +
@validator("batch_size_matting")
+def batch_size_matting_validator(cls, value: int, values):
+    if value > 0:
+        return value
+    else:
+        raise ValueError("Incorrect batch size!")
+
+
+
+def batch_size_seg_validator(value: int, values) +
+
+
+
+ +Expand source code + +
@validator("batch_size_seg")
+def batch_size_seg_validator(cls, value: int, values):
+    if value > 0:
+        return value
+    else:
+        raise ValueError("Incorrect batch size!")
+
+
+
+def device_validator(value) +
+
+
+
+ +Expand source code + +
@validator("device")
+def device_validator(cls, value):
+    if torch.cuda.is_available() is False and "cuda" in value:
+        raise ValueError(
+            "GPU is not available, but specified as processing device!"
+        )
+    if "cuda" not in value and "cpu" != value:
+        raise ValueError("Unknown processing device! It should be cpu or cuda!")
+    return value
+
+
+
+def matting_mask_size_validator(value: int, values) +
+
+
+
+ +Expand source code + +
@validator("matting_mask_size")
+def matting_mask_size_validator(cls, value: int, values):
+    if value > 0:
+        return value
+    else:
+        raise ValueError("Incorrect matting_mask_size!")
+
+
+
+def seg_mask_size_validator(value: int, values) +
+
+
+
+ +Expand source code + +
@validator("seg_mask_size")
+def seg_mask_size_validator(cls, value: int, values):
+    if value > 0:
+        return value
+    else:
+        raise ValueError("Incorrect seg_mask_size!")
+
+
+
+
+
+class WebAPIConfig +(**data: Any) +
+
+

FastAPI app config

+

Create a new model by parsing and validating input data from keyword arguments.

+

Raises ValidationError if the input data cannot be parsed to form a valid model.

+
+ +Expand source code + +
class WebAPIConfig(BaseModel):
+    """FastAPI app config"""
+
+    port: int = 5000
+    """Web API port"""
+    host: str = "0.0.0.0"
+    """Web API host"""
+    ml: MLConfig = MLConfig()
+    """Config for ml part of framework"""
+    auth: AuthConfig = AuthConfig()
+    """Config for web api token authentication """
+
+

Ancestors

+
    +
  • pydantic.main.BaseModel
  • +
  • pydantic.utils.Representation
  • +
+

Class variables

+
+
var auth : AuthConfig
+
+

Config for web api token authentication

+
+
var host : str
+
+

Web API host

+
+
var ml : MLConfig
+
+

Config for ml part of framework

+
+
var port : int
+
+

Web API port

+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/schemas/index.html b/docs/api/carvekit/web/schemas/index.html new file mode 100644 index 0000000..20ae6cd --- /dev/null +++ b/docs/api/carvekit/web/schemas/index.html @@ -0,0 +1,70 @@ + + + + + + +carvekit.web.schemas API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.schemas

+
+
+
+
+

Sub-modules

+
+
carvekit.web.schemas.config
+
+
+
+
carvekit.web.schemas.request
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/schemas/request.html b/docs/api/carvekit/web/schemas/request.html new file mode 100644 index 0000000..faaa723 --- /dev/null +++ b/docs/api/carvekit/web/schemas/request.html @@ -0,0 +1,397 @@ + + + + + + +carvekit.web.schemas.request API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.schemas.request

+
+
+
+ +Expand source code + +
import re
+from typing import Optional
+
+from pydantic import BaseModel, validator
+from typing_extensions import Literal
+
+
+class Parameters(BaseModel):
+    image_file_b64: Optional[str] = ""
+    image_url: Optional[str] = ""
+    size: Optional[Literal["preview", "full", "auto"]] = "preview"
+    type: Optional[
+        Literal["auto", "product", "person", "car"]
+    ] = "auto"  # Not supported at the moment
+    format: Optional[Literal["auto", "jpg", "png", "zip"]] = "auto"
+    roi: str = "0% 0% 100% 100%"
+    crop: bool = False
+    crop_margin: Optional[str] = "0px"
+    scale: Optional[str] = "original"
+    position: Optional[str] = "original"
+    channels: Optional[Literal["rgba", "alpha"]] = "rgba"
+    add_shadow: str = "false"  # Not supported at the moment
+    semitransparency: str = "false"  # Not supported at the moment
+    bg_color: Optional[str] = ""
+    bg_image_url: Optional[str] = ""
+
+    @validator("crop_margin")
+    def crop_margin_validator(cls, value):
+        if not re.match(r"[0-9]+(px|%)$", value):
+            raise ValueError(
+                "crop_margin paramter is not valid"
+            )  # TODO: Add support of several values
+        if "%" in value and (int(value[:-1]) < 0 or int(value[:-1]) > 100):
+            raise ValueError("crop_margin mast be in range between 0% and 100%")
+        return value
+
+    @validator("scale")
+    def scale_validator(cls, value):
+        if value != "original" and (
+            not re.match(r"[0-9]+%$", value)
+            or not int(value[:-1]) <= 100
+            or not int(value[:-1]) >= 10
+        ):
+            raise ValueError("scale must be original or in between of 10% and 100%")
+
+        if value == "original":
+            return 100
+
+        return int(value[:-1])
+
+    @validator("position")
+    def position_validator(cls, value, values):
+        if len(value.split(" ")) > 2:
+            raise ValueError(
+                "Position must be a value from 0 to 100 "
+                "for both vertical and horizontal axises or for both axises respectively"
+            )
+
+        if value == "original":
+            return "original"
+        elif len(value.split(" ")) == 1:
+            return [int(value[:-1]), int(value[:-1])]
+        else:
+            return [int(value.split(" ")[0][:-1]), int(value.split(" ")[1][:-1])]
+
+    @validator("bg_color")
+    def bg_color_validator(cls, value):
+        if not re.match(r"(#{0,1}[0-9a-f]{3}){0,2}$", value):
+            raise ValueError("bg_color is not in hex")
+        if len(value) and value[0] != "#":
+            value = "#" + value
+        return value
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class Parameters +(**data: Any) +
+
+

Create a new model by parsing and validating input data from keyword arguments.

+

Raises ValidationError if the input data cannot be parsed to form a valid model.

+
+ +Expand source code + +
class Parameters(BaseModel):
+    image_file_b64: Optional[str] = ""
+    image_url: Optional[str] = ""
+    size: Optional[Literal["preview", "full", "auto"]] = "preview"
+    type: Optional[
+        Literal["auto", "product", "person", "car"]
+    ] = "auto"  # Not supported at the moment
+    format: Optional[Literal["auto", "jpg", "png", "zip"]] = "auto"
+    roi: str = "0% 0% 100% 100%"
+    crop: bool = False
+    crop_margin: Optional[str] = "0px"
+    scale: Optional[str] = "original"
+    position: Optional[str] = "original"
+    channels: Optional[Literal["rgba", "alpha"]] = "rgba"
+    add_shadow: str = "false"  # Not supported at the moment
+    semitransparency: str = "false"  # Not supported at the moment
+    bg_color: Optional[str] = ""
+    bg_image_url: Optional[str] = ""
+
+    @validator("crop_margin")
+    def crop_margin_validator(cls, value):
+        if not re.match(r"[0-9]+(px|%)$", value):
+            raise ValueError(
+                "crop_margin paramter is not valid"
+            )  # TODO: Add support of several values
+        if "%" in value and (int(value[:-1]) < 0 or int(value[:-1]) > 100):
+            raise ValueError("crop_margin mast be in range between 0% and 100%")
+        return value
+
+    @validator("scale")
+    def scale_validator(cls, value):
+        if value != "original" and (
+            not re.match(r"[0-9]+%$", value)
+            or not int(value[:-1]) <= 100
+            or not int(value[:-1]) >= 10
+        ):
+            raise ValueError("scale must be original or in between of 10% and 100%")
+
+        if value == "original":
+            return 100
+
+        return int(value[:-1])
+
+    @validator("position")
+    def position_validator(cls, value, values):
+        if len(value.split(" ")) > 2:
+            raise ValueError(
+                "Position must be a value from 0 to 100 "
+                "for both vertical and horizontal axises or for both axises respectively"
+            )
+
+        if value == "original":
+            return "original"
+        elif len(value.split(" ")) == 1:
+            return [int(value[:-1]), int(value[:-1])]
+        else:
+            return [int(value.split(" ")[0][:-1]), int(value.split(" ")[1][:-1])]
+
+    @validator("bg_color")
+    def bg_color_validator(cls, value):
+        if not re.match(r"(#{0,1}[0-9a-f]{3}){0,2}$", value):
+            raise ValueError("bg_color is not in hex")
+        if len(value) and value[0] != "#":
+            value = "#" + value
+        return value
+
+

Ancestors

+
    +
  • pydantic.main.BaseModel
  • +
  • pydantic.utils.Representation
  • +
+

Class variables

+
+
var add_shadow : str
+
+
+
+
var bg_color : Optional[str]
+
+
+
+
var bg_image_url : Optional[str]
+
+
+
+
var channels : Optional[Literal['rgba', 'alpha']]
+
+
+
+
var crop : bool
+
+
+
+
var crop_margin : Optional[str]
+
+
+
+
var format : Optional[Literal['auto', 'jpg', 'png', 'zip']]
+
+
+
+
var image_file_b64 : Optional[str]
+
+
+
+
var image_url : Optional[str]
+
+
+
+
var position : Optional[str]
+
+
+
+
var roi : str
+
+
+
+
var scale : Optional[str]
+
+
+
+
var semitransparency : str
+
+
+
+
var size : Optional[Literal['preview', 'full', 'auto']]
+
+
+
+
var type : Optional[Literal['auto', 'product', 'person', 'car']]
+
+
+
+
+

Static methods

+
+
+def bg_color_validator(value) +
+
+
+
+ +Expand source code + +
@validator("bg_color")
+def bg_color_validator(cls, value):
+    if not re.match(r"(#{0,1}[0-9a-f]{3}){0,2}$", value):
+        raise ValueError("bg_color is not in hex")
+    if len(value) and value[0] != "#":
+        value = "#" + value
+    return value
+
+
+
+def crop_margin_validator(value) +
+
+
+
+ +Expand source code + +
@validator("crop_margin")
+def crop_margin_validator(cls, value):
+    if not re.match(r"[0-9]+(px|%)$", value):
+        raise ValueError(
+            "crop_margin paramter is not valid"
+        )  # TODO: Add support of several values
+    if "%" in value and (int(value[:-1]) < 0 or int(value[:-1]) > 100):
+        raise ValueError("crop_margin mast be in range between 0% and 100%")
+    return value
+
+
+
+def position_validator(value, values) +
+
+
+
+ +Expand source code + +
@validator("position")
+def position_validator(cls, value, values):
+    if len(value.split(" ")) > 2:
+        raise ValueError(
+            "Position must be a value from 0 to 100 "
+            "for both vertical and horizontal axises or for both axises respectively"
+        )
+
+    if value == "original":
+        return "original"
+    elif len(value.split(" ")) == 1:
+        return [int(value[:-1]), int(value[:-1])]
+    else:
+        return [int(value.split(" ")[0][:-1]), int(value.split(" ")[1][:-1])]
+
+
+
+def scale_validator(value) +
+
+
+
+ +Expand source code + +
@validator("scale")
+def scale_validator(cls, value):
+    if value != "original" and (
+        not re.match(r"[0-9]+%$", value)
+        or not int(value[:-1]) <= 100
+        or not int(value[:-1]) >= 10
+    ):
+        raise ValueError("scale must be original or in between of 10% and 100%")
+
+    if value == "original":
+        return 100
+
+    return int(value[:-1])
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/utils/index.html b/docs/api/carvekit/web/utils/index.html new file mode 100644 index 0000000..94dd635 --- /dev/null +++ b/docs/api/carvekit/web/utils/index.html @@ -0,0 +1,75 @@ + + + + + + +carvekit.web.utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.utils

+
+
+
+
+

Sub-modules

+
+
carvekit.web.utils.init_utils
+
+
+
+
carvekit.web.utils.net_utils
+
+
+
+
carvekit.web.utils.task_queue
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/utils/init_utils.html b/docs/api/carvekit/web/utils/init_utils.html new file mode 100644 index 0000000..da42da3 --- /dev/null +++ b/docs/api/carvekit/web/utils/init_utils.html @@ -0,0 +1,551 @@ + + + + + + +carvekit.web.utils.init_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.utils.init_utils

+
+
+
+ +Expand source code + +
import warnings
+from os import getenv
+from typing import Union
+
+from loguru import logger
+
+from carvekit.ml.wrap.cascadepsp import CascadePSP
+from carvekit.ml.wrap.scene_classifier import SceneClassifier
+from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig
+
+from carvekit.api.interface import Interface
+from carvekit.api.autointerface import AutoInterface
+
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
+from carvekit.ml.wrap.basnet import BASNET
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4
+
+
+from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod
+from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene
+from carvekit.trimap.generator import TrimapGenerator
+
+
+def init_config() -> WebAPIConfig:
+    default_config = WebAPIConfig()
+    config = WebAPIConfig(
+        **dict(
+            port=int(getenv("CARVEKIT_PORT", default_config.port)),
+            host=getenv("CARVEKIT_HOST", default_config.host),
+            ml=MLConfig(
+                segmentation_network=getenv(
+                    "CARVEKIT_SEGMENTATION_NETWORK",
+                    default_config.ml.segmentation_network,
+                ),
+                preprocessing_method=getenv(
+                    "CARVEKIT_PREPROCESSING_METHOD",
+                    default_config.ml.preprocessing_method,
+                ),
+                postprocessing_method=getenv(
+                    "CARVEKIT_POSTPROCESSING_METHOD",
+                    default_config.ml.postprocessing_method,
+                ),
+                device=getenv("CARVEKIT_DEVICE", default_config.ml.device),
+                batch_size_pre=int(
+                    getenv("CARVEKIT_BATCH_SIZE_PRE", default_config.ml.batch_size_pre)
+                ),
+                batch_size_seg=int(
+                    getenv("CARVEKIT_BATCH_SIZE_SEG", default_config.ml.batch_size_seg)
+                ),
+                batch_size_matting=int(
+                    getenv(
+                        "CARVEKIT_BATCH_SIZE_MATTING",
+                        default_config.ml.batch_size_matting,
+                    )
+                ),
+                batch_size_refine=int(
+                    getenv(
+                        "CARVEKIT_BATCH_SIZE_REFINE",
+                        default_config.ml.batch_size_refine,
+                    )
+                ),
+                seg_mask_size=int(
+                    getenv("CARVEKIT_SEG_MASK_SIZE", default_config.ml.seg_mask_size)
+                ),
+                matting_mask_size=int(
+                    getenv(
+                        "CARVEKIT_MATTING_MASK_SIZE",
+                        default_config.ml.matting_mask_size,
+                    )
+                ),
+                refine_mask_size=int(
+                    getenv(
+                        "CARVEKIT_REFINE_MASK_SIZE",
+                        default_config.ml.refine_mask_size,
+                    )
+                ),
+                fp16=bool(int(getenv("CARVEKIT_FP16", default_config.ml.fp16))),
+                trimap_prob_threshold=int(
+                    getenv(
+                        "CARVEKIT_TRIMAP_PROB_THRESHOLD",
+                        default_config.ml.trimap_prob_threshold,
+                    )
+                ),
+                trimap_dilation=int(
+                    getenv(
+                        "CARVEKIT_TRIMAP_DILATION", default_config.ml.trimap_dilation
+                    )
+                ),
+                trimap_erosion=int(
+                    getenv("CARVEKIT_TRIMAP_EROSION", default_config.ml.trimap_erosion)
+                ),
+            ),
+            auth=AuthConfig(
+                auth=bool(
+                    int(getenv("CARVEKIT_AUTH_ENABLE", default_config.auth.auth))
+                ),
+                admin_token=getenv(
+                    "CARVEKIT_ADMIN_TOKEN", default_config.auth.admin_token
+                ),
+                allowed_tokens=default_config.auth.allowed_tokens
+                if getenv("CARVEKIT_ALLOWED_TOKENS") is None
+                else getenv("CARVEKIT_ALLOWED_TOKENS").split(","),
+            ),
+        )
+    )
+
+    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
+    logger.debug(f"Running Web API with this config: {config.json()}")
+    return config
+
+
+def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
+    if isinstance(config, WebAPIConfig):
+        config = config.ml
+    if config.preprocessing_method == "auto":
+        warnings.warn(
+            "Preprocessing_method is set to `auto`."
+            "We will use automatic methods to determine the best methods for your images! "
+            "Please note that this is not always the best option and all other options will be ignored!"
+        )
+        scene_classifier = SceneClassifier(
+            device=config.device, batch_size=config.batch_size_pre, fp16=config.fp16
+        )
+        object_classifier = SimplifiedYoloV4(
+            device=config.device, batch_size=config.batch_size_pre, fp16=config.fp16
+        )
+        return AutoInterface(
+            scene_classifier=scene_classifier,
+            object_classifier=object_classifier,
+            segmentation_batch_size=config.batch_size_seg,
+            postprocessing_batch_size=config.batch_size_matting,
+            postprocessing_image_size=config.matting_mask_size,
+            segmentation_device=config.device,
+            postprocessing_device=config.device,
+            fp16=config.fp16,
+        )
+
+    else:
+        if config.segmentation_network == "u2net":
+            seg_net = U2NET(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        elif config.segmentation_network == "deeplabv3":
+            seg_net = DeepLabV3(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        elif config.segmentation_network == "basnet":
+            seg_net = BASNET(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        elif config.segmentation_network == "tracer_b7":
+            seg_net = TracerUniversalB7(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        else:
+            seg_net = TracerUniversalB7(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+
+        if config.preprocessing_method == "stub":
+            preprocessing = PreprocessingStub()
+        elif config.preprocessing_method == "none":
+            preprocessing = None
+        elif config.preprocessing_method == "autoscene":
+            preprocessing = AutoScene(
+                scene_classifier=SceneClassifier(
+                    device=config.device,
+                    batch_size=config.batch_size_pre,
+                    fp16=config.fp16,
+                )
+            )
+        else:
+            preprocessing = None
+
+        if config.postprocessing_method == "fba":
+            fba = FBAMatting(
+                device=config.device,
+                batch_size=config.batch_size_matting,
+                input_tensor_size=config.matting_mask_size,
+                fp16=config.fp16,
+            )
+            trimap_generator = TrimapGenerator(
+                prob_threshold=config.trimap_prob_threshold,
+                kernel_size=config.trimap_dilation,
+                erosion_iters=config.trimap_erosion,
+            )
+            postprocessing = MattingMethod(
+                device=config.device,
+                matting_module=fba,
+                trimap_generator=trimap_generator,
+            )
+        elif config.postprocessing_method == "cascade_fba":
+            cascadepsp = CascadePSP(
+                device=config.device,
+                batch_size=config.batch_size_refine,
+                input_tensor_size=config.refine_mask_size,
+                fp16=config.fp16,
+            )
+            fba = FBAMatting(
+                device=config.device,
+                batch_size=config.batch_size_matting,
+                input_tensor_size=config.matting_mask_size,
+                fp16=config.fp16,
+            )
+            trimap_generator = TrimapGenerator(
+                prob_threshold=config.trimap_prob_threshold,
+                kernel_size=config.trimap_dilation,
+                erosion_iters=config.trimap_erosion,
+            )
+            postprocessing = CasMattingMethod(
+                device=config.device,
+                matting_module=fba,
+                trimap_generator=trimap_generator,
+                refining_module=cascadepsp,
+            )
+        elif config.postprocessing_method == "none":
+            postprocessing = None
+        else:
+            postprocessing = None
+
+        interface = Interface(
+            pre_pipe=preprocessing,
+            post_pipe=postprocessing,
+            seg_pipe=seg_net,
+            device=config.device,
+        )
+    return interface
+
+
+
+
+
+
+
+

Functions

+
+
+def init_config() ‑> WebAPIConfig +
+
+
+
+ +Expand source code + +
def init_config() -> WebAPIConfig:
+    default_config = WebAPIConfig()
+    config = WebAPIConfig(
+        **dict(
+            port=int(getenv("CARVEKIT_PORT", default_config.port)),
+            host=getenv("CARVEKIT_HOST", default_config.host),
+            ml=MLConfig(
+                segmentation_network=getenv(
+                    "CARVEKIT_SEGMENTATION_NETWORK",
+                    default_config.ml.segmentation_network,
+                ),
+                preprocessing_method=getenv(
+                    "CARVEKIT_PREPROCESSING_METHOD",
+                    default_config.ml.preprocessing_method,
+                ),
+                postprocessing_method=getenv(
+                    "CARVEKIT_POSTPROCESSING_METHOD",
+                    default_config.ml.postprocessing_method,
+                ),
+                device=getenv("CARVEKIT_DEVICE", default_config.ml.device),
+                batch_size_pre=int(
+                    getenv("CARVEKIT_BATCH_SIZE_PRE", default_config.ml.batch_size_pre)
+                ),
+                batch_size_seg=int(
+                    getenv("CARVEKIT_BATCH_SIZE_SEG", default_config.ml.batch_size_seg)
+                ),
+                batch_size_matting=int(
+                    getenv(
+                        "CARVEKIT_BATCH_SIZE_MATTING",
+                        default_config.ml.batch_size_matting,
+                    )
+                ),
+                batch_size_refine=int(
+                    getenv(
+                        "CARVEKIT_BATCH_SIZE_REFINE",
+                        default_config.ml.batch_size_refine,
+                    )
+                ),
+                seg_mask_size=int(
+                    getenv("CARVEKIT_SEG_MASK_SIZE", default_config.ml.seg_mask_size)
+                ),
+                matting_mask_size=int(
+                    getenv(
+                        "CARVEKIT_MATTING_MASK_SIZE",
+                        default_config.ml.matting_mask_size,
+                    )
+                ),
+                refine_mask_size=int(
+                    getenv(
+                        "CARVEKIT_REFINE_MASK_SIZE",
+                        default_config.ml.refine_mask_size,
+                    )
+                ),
+                fp16=bool(int(getenv("CARVEKIT_FP16", default_config.ml.fp16))),
+                trimap_prob_threshold=int(
+                    getenv(
+                        "CARVEKIT_TRIMAP_PROB_THRESHOLD",
+                        default_config.ml.trimap_prob_threshold,
+                    )
+                ),
+                trimap_dilation=int(
+                    getenv(
+                        "CARVEKIT_TRIMAP_DILATION", default_config.ml.trimap_dilation
+                    )
+                ),
+                trimap_erosion=int(
+                    getenv("CARVEKIT_TRIMAP_EROSION", default_config.ml.trimap_erosion)
+                ),
+            ),
+            auth=AuthConfig(
+                auth=bool(
+                    int(getenv("CARVEKIT_AUTH_ENABLE", default_config.auth.auth))
+                ),
+                admin_token=getenv(
+                    "CARVEKIT_ADMIN_TOKEN", default_config.auth.admin_token
+                ),
+                allowed_tokens=default_config.auth.allowed_tokens
+                if getenv("CARVEKIT_ALLOWED_TOKENS") is None
+                else getenv("CARVEKIT_ALLOWED_TOKENS").split(","),
+            ),
+        )
+    )
+
+    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
+    logger.debug(f"Running Web API with this config: {config.json()}")
+    return config
+
+
+
+def init_interface(config: Union[WebAPIConfig, MLConfig]) ‑> Interface +
+
+
+
+ +Expand source code + +
def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
+    if isinstance(config, WebAPIConfig):
+        config = config.ml
+    if config.preprocessing_method == "auto":
+        warnings.warn(
+            "Preprocessing_method is set to `auto`."
+            "We will use automatic methods to determine the best methods for your images! "
+            "Please note that this is not always the best option and all other options will be ignored!"
+        )
+        scene_classifier = SceneClassifier(
+            device=config.device, batch_size=config.batch_size_pre, fp16=config.fp16
+        )
+        object_classifier = SimplifiedYoloV4(
+            device=config.device, batch_size=config.batch_size_pre, fp16=config.fp16
+        )
+        return AutoInterface(
+            scene_classifier=scene_classifier,
+            object_classifier=object_classifier,
+            segmentation_batch_size=config.batch_size_seg,
+            postprocessing_batch_size=config.batch_size_matting,
+            postprocessing_image_size=config.matting_mask_size,
+            segmentation_device=config.device,
+            postprocessing_device=config.device,
+            fp16=config.fp16,
+        )
+
+    else:
+        if config.segmentation_network == "u2net":
+            seg_net = U2NET(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        elif config.segmentation_network == "deeplabv3":
+            seg_net = DeepLabV3(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        elif config.segmentation_network == "basnet":
+            seg_net = BASNET(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        elif config.segmentation_network == "tracer_b7":
+            seg_net = TracerUniversalB7(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+        else:
+            seg_net = TracerUniversalB7(
+                device=config.device,
+                batch_size=config.batch_size_seg,
+                input_image_size=config.seg_mask_size,
+                fp16=config.fp16,
+            )
+
+        if config.preprocessing_method == "stub":
+            preprocessing = PreprocessingStub()
+        elif config.preprocessing_method == "none":
+            preprocessing = None
+        elif config.preprocessing_method == "autoscene":
+            preprocessing = AutoScene(
+                scene_classifier=SceneClassifier(
+                    device=config.device,
+                    batch_size=config.batch_size_pre,
+                    fp16=config.fp16,
+                )
+            )
+        else:
+            preprocessing = None
+
+        if config.postprocessing_method == "fba":
+            fba = FBAMatting(
+                device=config.device,
+                batch_size=config.batch_size_matting,
+                input_tensor_size=config.matting_mask_size,
+                fp16=config.fp16,
+            )
+            trimap_generator = TrimapGenerator(
+                prob_threshold=config.trimap_prob_threshold,
+                kernel_size=config.trimap_dilation,
+                erosion_iters=config.trimap_erosion,
+            )
+            postprocessing = MattingMethod(
+                device=config.device,
+                matting_module=fba,
+                trimap_generator=trimap_generator,
+            )
+        elif config.postprocessing_method == "cascade_fba":
+            cascadepsp = CascadePSP(
+                device=config.device,
+                batch_size=config.batch_size_refine,
+                input_tensor_size=config.refine_mask_size,
+                fp16=config.fp16,
+            )
+            fba = FBAMatting(
+                device=config.device,
+                batch_size=config.batch_size_matting,
+                input_tensor_size=config.matting_mask_size,
+                fp16=config.fp16,
+            )
+            trimap_generator = TrimapGenerator(
+                prob_threshold=config.trimap_prob_threshold,
+                kernel_size=config.trimap_dilation,
+                erosion_iters=config.trimap_erosion,
+            )
+            postprocessing = CasMattingMethod(
+                device=config.device,
+                matting_module=fba,
+                trimap_generator=trimap_generator,
+                refining_module=cascadepsp,
+            )
+        elif config.postprocessing_method == "none":
+            postprocessing = None
+        else:
+            postprocessing = None
+
+        interface = Interface(
+            pre_pipe=preprocessing,
+            post_pipe=postprocessing,
+            seg_pipe=seg_net,
+            device=config.device,
+        )
+    return interface
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/utils/net_utils.html b/docs/api/carvekit/web/utils/net_utils.html new file mode 100644 index 0000000..4df1b17 --- /dev/null +++ b/docs/api/carvekit/web/utils/net_utils.html @@ -0,0 +1,139 @@ + + + + + + +carvekit.web.utils.net_utils API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.utils.net_utils

+
+
+
+ +Expand source code + +
import socket
+import struct
+from typing import Optional
+from urllib.parse import urlparse
+
+
+def is_loopback(address):
+    host: Optional[str] = None
+
+    try:
+        parsed_url = urlparse(address)
+        host = parsed_url.hostname
+    except ValueError:
+        return False  # url is not even a url
+
+    loopback_checker = {
+        socket.AF_INET: lambda x: struct.unpack("!I", socket.inet_aton(x))[0]
+        >> (32 - 8)
+        == 127,
+        socket.AF_INET6: lambda x: x == "::1",
+    }
+    for family in (socket.AF_INET, socket.AF_INET6):
+        try:
+            r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
+        except socket.gaierror:
+            continue
+        for family, _, _, _, sockaddr in r:
+            if loopback_checker[family](sockaddr[0]):
+                return True
+
+    if host in ("localhost",):
+        return True
+
+    return False
+
+
+
+
+
+
+
+

Functions

+
+
+def is_loopback(address) +
+
+
+
+ +Expand source code + +
def is_loopback(address):
+    host: Optional[str] = None
+
+    try:
+        parsed_url = urlparse(address)
+        host = parsed_url.hostname
+    except ValueError:
+        return False  # url is not even a url
+
+    loopback_checker = {
+        socket.AF_INET: lambda x: struct.unpack("!I", socket.inet_aton(x))[0]
+        >> (32 - 8)
+        == 127,
+        socket.AF_INET6: lambda x: x == "::1",
+    }
+    for family in (socket.AF_INET, socket.AF_INET6):
+        try:
+            r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
+        except socket.gaierror:
+            continue
+        for family, _, _, _, sockaddr in r:
+            if loopback_checker[family](sockaddr[0]):
+                return True
+
+    if host in ("localhost",):
+        return True
+
+    return False
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/carvekit/web/utils/task_queue.html b/docs/api/carvekit/web/utils/task_queue.html new file mode 100644 index 0000000..f685f87 --- /dev/null +++ b/docs/api/carvekit/web/utils/task_queue.html @@ -0,0 +1,481 @@ + + + + + + +carvekit.web.utils.task_queue API documentation + + + + + + + + + + + +
+
+
+

Module carvekit.web.utils.task_queue

+
+
+
+ +Expand source code + +
import gc
+import threading
+import time
+import uuid
+from typing import Optional
+
+from loguru import logger
+
+from carvekit.api.interface import Interface
+from carvekit.web.schemas.config import WebAPIConfig
+from carvekit.web.utils.init_utils import init_interface
+from carvekit.web.other.removebg import process_remove_bg
+
+
+class MLProcessor(threading.Thread):
+    """Simple ml task queue processor"""
+
+    def __init__(self, api_config: WebAPIConfig):
+        super().__init__()
+        self.api_config = api_config
+        self.interface: Optional[Interface] = None
+        self.jobs = {}
+        self.completed_jobs = {}
+
+    def run(self):
+        """Starts listening for new jobs."""
+        unused_completed_jobs_timer = time.time()
+        if self.interface is None:
+            self.interface = init_interface(self.api_config)
+        while True:
+            # Once a minute, purge completed jobs that finished more than an hour ago
+            if time.time() - unused_completed_jobs_timer > 60:
+                self.clear_old_completed_jobs()
+                unused_completed_jobs_timer = time.time()
+
+            if len(self.jobs.keys()) >= 1:
+                id = list(self.jobs.keys())[0]
+                data = self.jobs[id]
+                # TODO add pydantic scheme here
+                response = process_remove_bg(
+                    self.interface, data[0], data[1], data[2], data[3]
+                )
+                self.completed_jobs[id] = [response, time.time()]
+                try:
+                    del self.jobs[id]
+                except KeyError or NameError as e:
+                    logger.error(f"Something went wrong with Task Queue: {str(e)}")
+                gc.collect()
+            else:
+                time.sleep(1)
+                continue
+
+    def clear_old_completed_jobs(self):
+        """Clears old completed jobs"""
+
+        if len(self.completed_jobs.keys()) >= 1:
+            for job_id in self.completed_jobs.keys():
+                job_finished_time = self.completed_jobs[job_id][1]
+                if time.time() - job_finished_time > 3600:
+                    try:
+                        del self.completed_jobs[job_id]
+                    except KeyError or NameError as e:
+                        logger.error(f"Something went wrong with Task Queue: {str(e)}")
+            gc.collect()
+
+    def job_status(self, id: str) -> str:
+        """
+        Returns current job status
+
+        Args:
+            id: id of the job
+
+        Returns:
+            Current job status for specified id. Job status can be [finished, wait, not_found]
+        """
+        if id in self.completed_jobs.keys():
+            return "finished"
+        elif id in self.jobs.keys():
+            return "wait"
+        else:
+            return "not_found"
+
+    def job_result(self, id: str):
+        """
+        Returns job processing result.
+
+        Args:
+            id: id of the job
+
+        Returns:
+            job processing result.
+        """
+        if id in self.completed_jobs.keys():
+            data = self.completed_jobs[id][0]
+            try:
+                del self.completed_jobs[id]
+            except KeyError or NameError:
+                pass
+            return data
+        else:
+            return False
+
+    def job_create(self, data: list):
+        """
+        Send job to ML Processor
+
+        Args:
+            data: data object
+        """
+        if self.is_alive() is False:
+            self.start()
+        id = uuid.uuid4().hex
+        self.jobs[id] = data
+        return id
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MLProcessor +(api_config: WebAPIConfig) +
+
+

Simple ml task queue processor

+

This constructor should always be called with keyword arguments. Arguments are:

+

group should be None; reserved for future extension when a ThreadGroup +class is implemented.

+

target is the callable object to be invoked by the run() +method. Defaults to None, meaning nothing is called.

+

name is the thread name. By default, a unique name is constructed of +the form "Thread-N" where N is a small decimal number.

+

args is the argument tuple for the target invocation. Defaults to ().

+

kwargs is a dictionary of keyword arguments for the target +invocation. Defaults to {}.

+

If a subclass overrides the constructor, it must make sure to invoke +the base class constructor (Thread.__init__()) before doing anything +else to the thread.

+
+ +Expand source code + +
class MLProcessor(threading.Thread):
+    """Simple ml task queue processor"""
+
+    def __init__(self, api_config: WebAPIConfig):
+        super().__init__()
+        self.api_config = api_config
+        self.interface: Optional[Interface] = None
+        self.jobs = {}
+        self.completed_jobs = {}
+
+    def run(self):
+        """Starts listening for new jobs."""
+        unused_completed_jobs_timer = time.time()
+        if self.interface is None:
+            self.interface = init_interface(self.api_config)
+        while True:
+            # Once a minute, purge completed jobs that finished more than an hour ago
+            if time.time() - unused_completed_jobs_timer > 60:
+                self.clear_old_completed_jobs()
+                unused_completed_jobs_timer = time.time()
+
+            if len(self.jobs.keys()) >= 1:
+                id = list(self.jobs.keys())[0]
+                data = self.jobs[id]
+                # TODO add pydantic scheme here
+                response = process_remove_bg(
+                    self.interface, data[0], data[1], data[2], data[3]
+                )
+                self.completed_jobs[id] = [response, time.time()]
+                try:
+                    del self.jobs[id]
+                except KeyError or NameError as e:
+                    logger.error(f"Something went wrong with Task Queue: {str(e)}")
+                gc.collect()
+            else:
+                time.sleep(1)
+                continue
+
+    def clear_old_completed_jobs(self):
+        """Clears old completed jobs"""
+
+        if len(self.completed_jobs.keys()) >= 1:
+            for job_id in self.completed_jobs.keys():
+                job_finished_time = self.completed_jobs[job_id][1]
+                if time.time() - job_finished_time > 3600:
+                    try:
+                        del self.completed_jobs[job_id]
+                    except KeyError or NameError as e:
+                        logger.error(f"Something went wrong with Task Queue: {str(e)}")
+            gc.collect()
+
+    def job_status(self, id: str) -> str:
+        """
+        Returns current job status
+
+        Args:
+            id: id of the job
+
+        Returns:
+            Current job status for specified id. Job status can be [finished, wait, not_found]
+        """
+        if id in self.completed_jobs.keys():
+            return "finished"
+        elif id in self.jobs.keys():
+            return "wait"
+        else:
+            return "not_found"
+
+    def job_result(self, id: str):
+        """
+        Returns job processing result.
+
+        Args:
+            id: id of the job
+
+        Returns:
+            job processing result.
+        """
+        if id in self.completed_jobs.keys():
+            data = self.completed_jobs[id][0]
+            try:
+                del self.completed_jobs[id]
+            except KeyError or NameError:
+                pass
+            return data
+        else:
+            return False
+
+    def job_create(self, data: list):
+        """
+        Send job to ML Processor
+
+        Args:
+            data: data object
+        """
+        if self.is_alive() is False:
+            self.start()
+        id = uuid.uuid4().hex
+        self.jobs[id] = data
+        return id
+
+

Ancestors

+
    +
  • threading.Thread
  • +
+

Methods

+
+
+def clear_old_completed_jobs(self) +
+
+

Clears old completed jobs

+
+ +Expand source code + +
def clear_old_completed_jobs(self):
+    """Clears old completed jobs"""
+
+    if len(self.completed_jobs.keys()) >= 1:
+        for job_id in self.completed_jobs.keys():
+            job_finished_time = self.completed_jobs[job_id][1]
+            if time.time() - job_finished_time > 3600:
+                try:
+                    del self.completed_jobs[job_id]
+                except KeyError or NameError as e:
+                    logger.error(f"Something went wrong with Task Queue: {str(e)}")
+        gc.collect()
+
+
+
+def job_create(self, data: list) +
+
+

Send job to ML Processor

+

Args

+
+
data
+
data object
+
+
+ +Expand source code + +
def job_create(self, data: list):
+    """
+    Send job to ML Processor
+
+    Args:
+        data: data object
+    """
+    if self.is_alive() is False:
+        self.start()
+    id = uuid.uuid4().hex
+    self.jobs[id] = data
+    return id
+
+
+
+def job_result(self, id: str) +
+
+

Returns job processing result.

+

Args

+
+
id
+
id of the job
+
+

Returns

+

job processing result.

+
+ +Expand source code + +
def job_result(self, id: str):
+    """
+    Returns job processing result.
+
+    Args:
+        id: id of the job
+
+    Returns:
+        job processing result.
+    """
+    if id in self.completed_jobs.keys():
+        data = self.completed_jobs[id][0]
+        try:
+            del self.completed_jobs[id]
+        except KeyError or NameError:
+            pass
+        return data
+    else:
+        return False
+
+
+
+def job_status(self, id: str) ‑> str +
+
+

Returns current job status

+

Args

+
+
id
+
id of the job
+
+

Returns

+

Current job status for specified id. Job status can be [finished, wait, not_found]

+
+ +Expand source code + +
def job_status(self, id: str) -> str:
+    """
+    Returns current job status
+
+    Args:
+        id: id of the job
+
+    Returns:
+        Current job status for specified id. Job status can be [finished, wait, not_found]
+    """
+    if id in self.completed_jobs.keys():
+        return "finished"
+    elif id in self.jobs.keys():
+        return "wait"
+    else:
+        return "not_found"
+
+
+
+def run(self) +
+
+

Starts listening for new jobs.

+
+ +Expand source code + +
def run(self):
+    """Starts listening for new jobs."""
+    unused_completed_jobs_timer = time.time()
+    if self.interface is None:
+        self.interface = init_interface(self.api_config)
+    while True:
+        # Once a minute, purge completed jobs that finished more than an hour ago
+        if time.time() - unused_completed_jobs_timer > 60:
+            self.clear_old_completed_jobs()
+            unused_completed_jobs_timer = time.time()
+
+        if len(self.jobs.keys()) >= 1:
+            id = list(self.jobs.keys())[0]
+            data = self.jobs[id]
+            # TODO add pydantic scheme here
+            response = process_remove_bg(
+                self.interface, data[0], data[1], data[2], data[3]
+            )
+            self.completed_jobs[id] = [response, time.time()]
+            try:
+                del self.jobs[id]
+            except KeyError or NameError as e:
+                logger.error(f"Something went wrong with Task Queue: {str(e)}")
+            gc.collect()
+        else:
+            time.sleep(1)
+            continue
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/cascadepsp.html b/docs/api/cascadepsp.html new file mode 100644 index 0000000..62d0a40 --- /dev/null +++ b/docs/api/cascadepsp.html @@ -0,0 +1,902 @@ + + + + + + +cascadepsp API documentation + + + + + + + + + + + +
+
+
+

Module cascadepsp

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+import warnings
+
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+from typing import Union, List
+
+from carvekit.ml.arch.cascadepsp.pspnet import RefinementModule
+from carvekit.ml.arch.cascadepsp.utils import (
+    process_im_single_pass,
+    process_high_res_im,
+)
+from carvekit.ml.files.models_loc import cascadepsp_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["CascadePSP"]
+
+
+class CascadePSP(RefinementModule):
+    """
+    CascadePSP to refine the mask from segmentation network
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        input_tensor_size: int = 900,
+        batch_size: int = 1,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        mask_binary_threshold=127,
+        global_step_only=False,
+        processing_accelerate_image_size=2048,
+    ):
+        """
+        Initialize the CascadePSP model
+
+        Args:
+            device: processing device
+            input_tensor_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use half precision
+            global_step_only: if True, only global step will be used for prediction. See paper for details.
+            mask_binary_threshold: threshold for binary mask, default 127, set to 0 for no threshold
+            processing_accelerate_image_size: thumbnail size for image processing acceleration. Set to 0 to disable
+
+        """
+        super().__init__()
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        self.mask_binary_threshold = mask_binary_threshold
+        self.global_step_only = global_step_only
+        self.processing_accelerate_image_size = processing_accelerate_image_size
+        self.input_tensor_size = input_tensor_size
+
+        self.to(device)
+        if batch_size > 1:
+            warnings.warn(
+                "Batch size > 1 is experimental feature for CascadePSP."
+                " Please, don't use it if you have GPU with small memory!"
+            )
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(cascadepsp_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+        self._image_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+        self._seg_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5], std=[0.5]),
+            ]
+        )
+
+    def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+        preprocessed_data = data.copy()
+        if self.batch_size == 1 and self.processing_accelerate_image_size > 0:
+            # Okay, we have only one image, so
+            # we can use image processing acceleration to speed up high-resolution image processing
+            preprocessed_data.thumbnail(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif self.batch_size == 1:
+            pass  # No need to do anything
+        elif self.batch_size > 1 and self.global_step_only is True:
+            # If we have more than one image and we use only global step,
+            # there isn't any reason to use image processing acceleration,
+            # because we will use only global step for prediction and anyway it will be resized to input_tensor_size
+            preprocessed_data = preprocessed_data.resize(
+                (self.input_tensor_size, self.input_tensor_size)
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and self.processing_accelerate_image_size > 0
+        ):
+            # If we have more than one image and we use local step,
+            # we can use image processing acceleration to speed up high-resolution image processing
+            # but we need to resize image to processing_accelerate_image_size to stack it with other images
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and not (self.processing_accelerate_image_size > 0)
+        ):
+            raise ValueError(
+                "If you use local step with batch_size > 2, "
+                "you need to set processing_accelerate_image_size > 0,"
+                "since we cannot stack images with different sizes to one batch"
+            )
+        else:  # some extra cases
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+
+        if data.mode == "RGB":
+            preprocessed_data = self._image_transform(
+                np.array(preprocessed_data)
+            ).unsqueeze(0)
+        elif data.mode == "L":
+            preprocessed_data = np.array(preprocessed_data)
+            if 0 < self.mask_binary_threshold <= 255:
+                preprocessed_data = (
+                    preprocessed_data > self.mask_binary_threshold
+                ).astype(np.uint8) * 255
+            elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0:
+                warnings.warn(
+                    "mask_binary_threshold should be in range [0, 255], "
+                    "but got {}. Disabling mask_binary_threshold!".format(
+                        self.mask_binary_threshold
+                    )
+                )
+
+            preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze(
+                0
+            )  # [H,W,1]
+
+        return preprocessed_data
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, mask: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            mask: input mask
+
+        Returns:
+            Segmentation mask as PIL Image instance
+
+        """
+        refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8")
+        return Image.fromarray(refined_mask).convert("L").resize(mask.size)
+
+    def safe_forward(self, im, seg, inter_s8=None, inter_s4=None):
+        """
+        Slightly pads the input image such that its length is a multiple of 8
+        """
+        b, _, ph, pw = seg.shape
+        if (ph % 8 != 0) or (pw % 8 != 0):
+            newH = (ph // 8 + 1) * 8
+            newW = (pw // 8 + 1) * 8
+            p_im = torch.zeros(b, 3, newH, newW, device=im.device)
+            p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+
+            p_im[:, :, 0:ph, 0:pw] = im
+            p_seg[:, :, 0:ph, 0:pw] = seg
+            im = p_im
+            seg = p_seg
+
+            if inter_s8 is not None:
+                p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8
+                inter_s8 = p_inter_s8
+            if inter_s4 is not None:
+                p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4
+                inter_s4 = p_inter_s4
+
+        images = super().__call__(im, seg, inter_s8, inter_s4)
+        return_im = {}
+
+        for key in ["pred_224", "pred_28_3", "pred_56_2"]:
+            return_im[key] = images[key][:, :, 0:ph, 0:pw]
+        del images
+
+        return return_im
+
+    def __call__(
+        self,
+        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
+        masks: List[Union[str, pathlib.Path, PIL.Image.Image]],
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+            masks: Segmentation masks to refine
+
+        Returns:
+            segmentation masks as for input images, as PIL.Image.Image instances
+
+        """
+
+        if len(images) != len(masks):
+            raise ValueError(
+                "Len of specified arrays of images and trimaps should be equal!"
+            )
+
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for idx_batch in batch_generator(range(len(images)), self.batch_size):
+                inpt_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(images[x])), idx_batch
+                )
+
+                inpt_masks = thread_pool_processing(
+                    lambda x: convert_image(load_image(masks[x]), mode="L"), idx_batch
+                )
+
+                inpt_img_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_images
+                )
+                inpt_masks_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_masks
+                )
+                if self.batch_size > 1:  # We need to stack images, if batch_size > 1
+                    inpt_img_batches = torch.vstack(inpt_img_batches)
+                    inpt_masks_batches = torch.vstack(inpt_masks_batches)
+                else:
+                    inpt_img_batches = inpt_img_batches[
+                        0
+                    ]  # Get only one image from list
+                    inpt_masks_batches = inpt_masks_batches[0]
+
+                with torch.no_grad():
+                    inpt_img_batches = inpt_img_batches.to(self.device)
+                    inpt_masks_batches = inpt_masks_batches.to(self.device)
+                    if self.global_step_only:
+                        refined_batches = process_im_single_pass(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    else:
+                        refined_batches = process_high_res_im(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    refined_masks = refined_batches.cpu()
+                    del (inpt_img_batches, inpt_masks_batches, refined_batches)
+                collect_masks += thread_pool_processing(
+                    lambda x: self.data_postprocessing(refined_masks[x], inpt_masks[x]),
+                    range(len(inpt_masks)),
+                )
+            return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CascadePSP +(device='cpu', input_tensor_size:Β intΒ =Β 900, batch_size:Β intΒ =Β 1, load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False, mask_binary_threshold=127, global_step_only=False, processing_accelerate_image_size=2048) +
+
+

CascadePSP to refine the mask from segmentation network

+

Initialize the CascadePSP model

+

Args

+
+
device
+
processing device
+
input_tensor_size
+
input image size
+
batch_size
+
the number of images that the neural network processes in one run
+
load_pretrained
+
loading pretrained model
+
fp16
+
use half precision
+
global_step_only
+
if True, only global step will be used for prediction. See paper for details.
+
mask_binary_threshold
+
threshold for binary mask, default 127, set to 0 for no threshold
+
processing_accelerate_image_size
+
thumbnail size for image processing acceleration. Set to 0 to disable
+
+
+ +Expand source code + +
class CascadePSP(RefinementModule):
+    """
+    CascadePSP to refine the mask from segmentation network
+    """
+
+    def __init__(
+        self,
+        device="cpu",
+        input_tensor_size: int = 900,
+        batch_size: int = 1,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        mask_binary_threshold=127,
+        global_step_only=False,
+        processing_accelerate_image_size=2048,
+    ):
+        """
+        Initialize the CascadePSP model
+
+        Args:
+            device: processing device
+            input_tensor_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use half precision
+            global_step_only: if True, only global step will be used for prediction. See paper for details.
+            mask_binary_threshold: threshold for binary mask, default 127, set to 0 for no threshold
+            processing_accelerate_image_size: thumbnail size for image processing acceleration. Set to 0 to disable
+
+        """
+        super().__init__()
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        self.mask_binary_threshold = mask_binary_threshold
+        self.global_step_only = global_step_only
+        self.processing_accelerate_image_size = processing_accelerate_image_size
+        self.input_tensor_size = input_tensor_size
+
+        self.to(device)
+        if batch_size > 1:
+            warnings.warn(
+                "Batch size > 1 is experimental feature for CascadePSP."
+                " Please, don't use it if you have GPU with small memory!"
+            )
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(cascadepsp_pretrained(), map_location=self.device)
+            )
+        self.eval()
+
+        self._image_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+        self._seg_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5], std=[0.5]),
+            ]
+        )
+
+    def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+        preprocessed_data = data.copy()
+        if self.batch_size == 1 and self.processing_accelerate_image_size > 0:
+            # Okay, we have only one image, so
+            # we can use image processing acceleration for accelerate high resolution image processing
+            preprocessed_data.thumbnail(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif self.batch_size == 1:
+            pass  # No need to do anything
+        elif self.batch_size > 1 and self.global_step_only is True:
+            # If we have more than one image and we use only global step,
+            # there aren't any reason to use image processing acceleration,
+            # because we will use only global step for prediction and anyway it will be resized to input_tensor_size
+            preprocessed_data = preprocessed_data.resize(
+                (self.input_tensor_size, self.input_tensor_size)
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and self.processing_accelerate_image_size > 0
+        ):
+            # If we have more than one image and we use local step,
+            # we can use image processing acceleration for accelerate high resolution image processing
+            # but we need to resize image to processing_accelerate_image_size to stack it with other images
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+        elif (
+            self.batch_size > 1
+            and self.global_step_only is False
+            and not (self.processing_accelerate_image_size > 0)
+        ):
+            raise ValueError(
+                "If you use local step with batch_size > 2, "
+                "you need to set processing_accelerate_image_size > 0,"
+                "since we cannot stack images with different sizes to one batch"
+            )
+        else:  # some extra cases
+            preprocessed_data = preprocessed_data.resize(
+                (
+                    self.processing_accelerate_image_size,
+                    self.processing_accelerate_image_size,
+                )
+            )
+
+        if data.mode == "RGB":
+            preprocessed_data = self._image_transform(
+                np.array(preprocessed_data)
+            ).unsqueeze(0)
+        elif data.mode == "L":
+            preprocessed_data = np.array(preprocessed_data)
+            if 0 < self.mask_binary_threshold <= 255:
+                preprocessed_data = (
+                    preprocessed_data > self.mask_binary_threshold
+                ).astype(np.uint8) * 255
+            elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0:
+                warnings.warn(
+                    "mask_binary_threshold should be in range [0, 255], "
+                    "but got {}. Disabling mask_binary_threshold!".format(
+                        self.mask_binary_threshold
+                    )
+                )
+
+            preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze(
+                0
+            )  # [H,W,1]
+
+        return preprocessed_data
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, mask: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+            mask: input mask
+
+        Returns:
+            Segmentation mask as PIL Image instance
+
+        """
+        refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8")
+        return Image.fromarray(refined_mask).convert("L").resize(mask.size)
+
+    def safe_forward(self, im, seg, inter_s8=None, inter_s4=None):
+        """
+        Slightly pads the input image such that its length is a multiple of 8
+        """
+        b, _, ph, pw = seg.shape
+        if (ph % 8 != 0) or (pw % 8 != 0):
+            newH = (ph // 8 + 1) * 8
+            newW = (pw // 8 + 1) * 8
+            p_im = torch.zeros(b, 3, newH, newW, device=im.device)
+            p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+
+            p_im[:, :, 0:ph, 0:pw] = im
+            p_seg[:, :, 0:ph, 0:pw] = seg
+            im = p_im
+            seg = p_seg
+
+            if inter_s8 is not None:
+                p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8
+                inter_s8 = p_inter_s8
+            if inter_s4 is not None:
+                p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+                p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4
+                inter_s4 = p_inter_s4
+
+        images = super().__call__(im, seg, inter_s8, inter_s4)
+        return_im = {}
+
+        for key in ["pred_224", "pred_28_3", "pred_56_2"]:
+            return_im[key] = images[key][:, :, 0:ph, 0:pw]
+        del images
+
+        return return_im
+
+    def __call__(
+        self,
+        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
+        masks: List[Union[str, pathlib.Path, PIL.Image.Image]],
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+            masks: Segmentation masks to refine
+
+        Returns:
+            segmentation masks as for input images, as PIL.Image.Image instances
+
+        """
+
+        if len(images) != len(masks):
+            raise ValueError(
+                "Len of specified arrays of images and trimaps should be equal!"
+            )
+
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for idx_batch in batch_generator(range(len(images)), self.batch_size):
+                inpt_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(images[x])), idx_batch
+                )
+
+                inpt_masks = thread_pool_processing(
+                    lambda x: convert_image(load_image(masks[x]), mode="L"), idx_batch
+                )
+
+                inpt_img_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_images
+                )
+                inpt_masks_batches = thread_pool_processing(
+                    self.data_preprocessing, inpt_masks
+                )
+                if self.batch_size > 1:  # We need to stack images, if batch_size > 1
+                    inpt_img_batches = torch.vstack(inpt_img_batches)
+                    inpt_masks_batches = torch.vstack(inpt_masks_batches)
+                else:
+                    inpt_img_batches = inpt_img_batches[
+                        0
+                    ]  # Get only one image from list
+                    inpt_masks_batches = inpt_masks_batches[0]
+
+                with torch.no_grad():
+                    inpt_img_batches = inpt_img_batches.to(self.device)
+                    inpt_masks_batches = inpt_masks_batches.to(self.device)
+                    if self.global_step_only:
+                        refined_batches = process_im_single_pass(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    else:
+                        refined_batches = process_high_res_im(
+                            self,
+                            inpt_img_batches,
+                            inpt_masks_batches,
+                            self.input_tensor_size,
+                        )
+
+                    refined_masks = refined_batches.cpu()
+                    del (inpt_img_batches, inpt_masks_batches, refined_batches)
+                collect_masks += thread_pool_processing(
+                    lambda x: self.data_postprocessing(refined_masks[x], inpt_masks[x]),
+                    range(len(inpt_masks)),
+                )
+            return collect_masks
+
+

Ancestors

+ +

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, mask:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data
+
output data from neural network
+
mask
+
input mask
+
+

Returns

+

Segmentation mask as PIL Image instance

+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, mask: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data: output data from neural network
+        mask: input mask
+
+    Returns:
+        Segmentation mask as PIL Image instance
+
+    """
+    refined_mask = (data[0, :, :].cpu().numpy() * 255).astype("uint8")
+    return Image.fromarray(refined_mask).convert("L").resize(mask.size)
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data
+
input image
+
+

Returns

+

input for neural network

+
+ +Expand source code + +
def data_preprocessing(self, data: Union[PIL.Image.Image]) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data: input image
+
+    Returns:
+        input for neural network
+
+    """
+    preprocessed_data = data.copy()
+    if self.batch_size == 1 and self.processing_accelerate_image_size > 0:
+        # Okay, we have only one image, so
+        # we can use image processing acceleration for accelerate high resolution image processing
+        preprocessed_data.thumbnail(
+            (
+                self.processing_accelerate_image_size,
+                self.processing_accelerate_image_size,
+            )
+        )
+    elif self.batch_size == 1:
+        pass  # No need to do anything
+    elif self.batch_size > 1 and self.global_step_only is True:
+        # If we have more than one image and we use only global step,
+        # there aren't any reason to use image processing acceleration,
+        # because we will use only global step for prediction and anyway it will be resized to input_tensor_size
+        preprocessed_data = preprocessed_data.resize(
+            (self.input_tensor_size, self.input_tensor_size)
+        )
+    elif (
+        self.batch_size > 1
+        and self.global_step_only is False
+        and self.processing_accelerate_image_size > 0
+    ):
+        # If we have more than one image and we use local step,
+        # we can use image processing acceleration for accelerate high resolution image processing
+        # but we need to resize image to processing_accelerate_image_size to stack it with other images
+        preprocessed_data = preprocessed_data.resize(
+            (
+                self.processing_accelerate_image_size,
+                self.processing_accelerate_image_size,
+            )
+        )
+    elif (
+        self.batch_size > 1
+        and self.global_step_only is False
+        and not (self.processing_accelerate_image_size > 0)
+    ):
+        raise ValueError(
+            "If you use local step with batch_size > 2, "
+            "you need to set processing_accelerate_image_size > 0,"
+            "since we cannot stack images with different sizes to one batch"
+        )
+    else:  # some extra cases
+        preprocessed_data = preprocessed_data.resize(
+            (
+                self.processing_accelerate_image_size,
+                self.processing_accelerate_image_size,
+            )
+        )
+
+    if data.mode == "RGB":
+        preprocessed_data = self._image_transform(
+            np.array(preprocessed_data)
+        ).unsqueeze(0)
+    elif data.mode == "L":
+        preprocessed_data = np.array(preprocessed_data)
+        if 0 < self.mask_binary_threshold <= 255:
+            preprocessed_data = (
+                preprocessed_data > self.mask_binary_threshold
+            ).astype(np.uint8) * 255
+        elif self.mask_binary_threshold > 255 or self.mask_binary_threshold < 0:
+            warnings.warn(
+                "mask_binary_threshold should be in range [0, 255], "
+                "but got {}. Disabling mask_binary_threshold!".format(
+                    self.mask_binary_threshold
+                )
+            )
+
+        preprocessed_data = self._seg_transform(preprocessed_data).unsqueeze(
+            0
+        )  # [H,W,1]
+
+    return preprocessed_data
+
+
+
+def safe_forward(self, im, seg, inter_s8=None, inter_s4=None) +
+
+

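Pads the input image and segmentation (at the bottom/right) so that their height and width are multiples of 8, runs the network, and crops the predictions back to the original size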

+
+ +Expand source code + +
def safe_forward(self, im, seg, inter_s8=None, inter_s4=None):
+    """
+    Slightly pads the input image such that its length is a multiple of 8
+    """
+    b, _, ph, pw = seg.shape
+    if (ph % 8 != 0) or (pw % 8 != 0):
+        newH = (ph // 8 + 1) * 8
+        newW = (pw // 8 + 1) * 8
+        p_im = torch.zeros(b, 3, newH, newW, device=im.device)
+        p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+
+        p_im[:, :, 0:ph, 0:pw] = im
+        p_seg[:, :, 0:ph, 0:pw] = seg
+        im = p_im
+        seg = p_seg
+
+        if inter_s8 is not None:
+            p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+            p_inter_s8[:, :, 0:ph, 0:pw] = inter_s8
+            inter_s8 = p_inter_s8
+        if inter_s4 is not None:
+            p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
+            p_inter_s4[:, :, 0:ph, 0:pw] = inter_s4
+            inter_s4 = p_inter_s4
+
+    images = super().__call__(im, seg, inter_s8, inter_s4)
+    return_im = {}
+
+    for key in ["pred_224", "pred_28_3", "pred_56_2"]:
+        return_im[key] = images[key][:, :, 0:ph, 0:pw]
+    del images
+
+    return return_im
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/deeplab_v3.html b/docs/api/deeplab_v3.html new file mode 100644 index 0000000..65e14cc --- /dev/null +++ b/docs/api/deeplab_v3.html @@ -0,0 +1,485 @@ + + + + + + +deeplab_v3 API documentation + + + + + + + + + + + +
+
+
+

Module deeplab_v3

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import pathlib
+from typing import List, Union
+
+import PIL.Image
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.models.segmentation import deeplabv3_resnet101
+from carvekit.ml.files.models_loc import deeplab_pretrained
+from carvekit.utils.image_utils import convert_image, load_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+
+__all__ = ["DeepLabV3"]
+
+
+class DeepLabV3:
+    def __init__(
+        self,
+        device="cpu",
+        batch_size: int = 10,
+        input_image_size: Union[List[int], int] = 1024,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the `DeepLabV3` model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=1024): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use half precision
+
+        """
+        self.device = device
+        self.batch_size = batch_size
+        self.network = deeplabv3_resnet101(
+            pretrained=False, pretrained_backbone=False, aux_loss=True
+        )
+        self.network.to(self.device)
+        if load_pretrained:
+            self.network.load_state_dict(
+                torch.load(deeplab_pretrained(), map_location=self.device)
+            )
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.network.eval()
+        self.fp16 = fp16
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+    def to(self, device: str):
+        """
+        Moves neural network to specified processing device
+
+        Args:
+            device (Literal[cpu, cuda]): the desired device.
+
+        """
+        self.network.to(device)
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        copy = data.copy()
+        copy.thumbnail(self.input_image_size, resample=3)
+        return self.transform(copy)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        return (
+            Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size)
+        )
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks as for input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.network, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = thread_pool_processing(
+                    self.data_preprocessing, converted_images
+                )
+                with torch.no_grad():
+                    masks = [
+                        self.network(i.to(self.device).unsqueeze(0))["out"][0]
+                        .argmax(0)
+                        .byte()
+                        .cpu()
+                        for i in batches
+                    ]
+                    del batches
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks[x], converted_images[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class DeepLabV3 +(device='cpu', batch_size:Β intΒ =Β 10, input_image_size:Β Union[List[int],Β int]Β =Β 1024, load_pretrained:Β boolΒ =Β True, fp16:Β boolΒ =Β False) +
+
+

Initialize the DeepLabV3 model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_image_size : Union[List[int], int], default=1024 - input image size (an int N is used as (N, N))
+
batch_size : int, default=10
+
the number of images that the neural network processes in one run
+
load_pretrained : bool, default=True
+
loading pretrained model
+
fp16 : bool, default=False
+
use half precision
+
+
+ +Expand source code + +
class DeepLabV3:
+    def __init__(
+        self,
+        device="cpu",
+        batch_size: int = 10,
+        input_image_size: Union[List[int], int] = 1024,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the `DeepLabV3` model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=1024): input image size
+            batch_size (int, default=10): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use half precision
+
+        """
+        self.device = device
+        self.batch_size = batch_size
+        self.network = deeplabv3_resnet101(
+            pretrained=False, pretrained_backbone=False, aux_loss=True
+        )
+        self.network.to(self.device)
+        if load_pretrained:
+            self.network.load_state_dict(
+                torch.load(deeplab_pretrained(), map_location=self.device)
+            )
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.network.eval()
+        self.fp16 = fp16
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+    def to(self, device: str):
+        """
+        Moves neural network to specified processing device
+
+        Args:
+            device (Literal[cpu, cuda]): the desired device.
+
+        """
+        self.network.to(device)
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.Tensor: input for neural network
+
+        """
+        copy = data.copy()
+        copy.thumbnail(self.input_image_size, resample=3)
+        return self.transform(copy)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        return (
+            Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size)
+        )
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through neural network and returns segmentation masks as `PIL.Image.Image` instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks as for input images, as `PIL.Image.Image` instances
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.network, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = thread_pool_processing(
+                    self.data_preprocessing, converted_images
+                )
+                with torch.no_grad():
+                    masks = [
+                        self.network(i.to(self.device).unsqueeze(0))["out"][0]
+                        .argmax(0)
+                        .byte()
+                        .cpu()
+                        for i in batches
+                    ]
+                    del batches
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks[x], converted_images[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+        return collect_masks
+
+

Static methods

+
+
+def data_postprocessing(data:Β torch.Tensor, original_image:Β PIL.Image.Image) ‑>Β PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask as PIL Image instance
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+    """
+    return (
+        Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size)
+    )
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.Tensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.Tensor: input for neural network
+
+    """
+    copy = data.copy()
+    copy.thumbnail(self.input_image_size, resample=3)
+    return self.transform(copy)
+
+
+
+def to(self, device: str) +
+
+

Moves neural network to specified processing device

+

Args

+
+
device : Literal[cpu, cuda]
+
the desired device.
+
+
+ +Expand source code + +
def to(self, device: str):
+    """
+    Moves neural network to specified processing device
+
+    Args:
+        device (Literal[cpu, cuda]): the desired device.
+
+    """
+    self.network.to(device)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/download_models.html b/docs/api/download_models.html new file mode 100644 index 0000000..643813d --- /dev/null +++ b/docs/api/download_models.html @@ -0,0 +1,775 @@ + + + + + + +download_models API documentation + + + + + + + + + + + +
+
+
+

Module download_models

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import hashlib
+import os
+import warnings
+from abc import ABCMeta, abstractmethod, ABC
+from pathlib import Path
+from typing import Optional
+
+import carvekit
+from carvekit.ml.files import checkpoints_dir
+
+import requests
+import tqdm
+
+requests = requests.Session()
+requests.headers.update({"User-Agent": f"Carvekit/{carvekit.version}"})
+
+MODELS_URLS = {
+    "basnet.pth": {
+        "repository": "Carve/basnet-universal",
+        "revision": "870becbdb364fda6d8fdb2c10b072542f8d08701",
+        "filename": "basnet.pth",
+    },
+    "deeplab.pth": {
+        "repository": "Carve/deeplabv3-resnet101",
+        "revision": "d504005392fc877565afdf58aad0cd524682d2b0",
+        "filename": "deeplab.pth",
+    },
+    "fba_matting.pth": {
+        "repository": "Carve/fba",
+        "revision": "a5d3457df0fb9c88ea19ed700d409756ca2069d1",
+        "filename": "fba_matting.pth",
+    },
+    "u2net.pth": {
+        "repository": "Carve/u2net-universal",
+        "revision": "10305d785481cf4b2eee1d447c39cd6e5f43d74b",
+        "filename": "full_weights.pth",
+    },
+    "tracer_b7.pth": {
+        "repository": "Carve/tracer_b7",
+        "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5",
+        "filename": "tracer_b7.pth",
+    },
+    "scene_classifier.pth": {
+        "repository": "Carve/scene_classifier",
+        "revision": "71c8e4c771dd5a20ff0c5c9e3c8f1c9cf8082740",
+        "filename": "scene_classifier.pth",
+    },
+    "yolov4_coco_with_classes.pth": {
+        "repository": "Carve/yolov4_coco",
+        "revision": "e3fc9cd22f86e456d2749d1ae148400f2f950fb3",
+        "filename": "yolov4_coco_with_classes.pth",
+    },
+    "cascadepsp.pth": {
+        "repository": "Carve/cascadepsp",
+        "revision": "3ca1e5e432344b1277bc88d1c6d4265c46cff62f",
+        "filename": "cascadepsp.pth",
+    },
+}
+"""
+All data needed to build path relative to huggingface.co for model download
+"""
+
+MODELS_CHECKSUMS = {
+    "basnet.pth": "e409cb709f4abca87cb11bd44a9ad3f909044a917977ab65244b4c94dd33"
+    "8b1a37755c4253d7cb54526b7763622a094d7b676d34b5e6886689256754e5a5e6ad",
+    "deeplab.pth": "9c5a1795bc8baa267200a44b49ac544a1ba2687d210f63777e4bd715387324469a59b072f8a28"
+    "9cc471c637b367932177e5b312e8ea6351c1763d9ff44b4857c",
+    "fba_matting.pth": "890906ec94c1bfd2ad08707a63e4ccb0955d7f5d25e32853950c24c78"
+    "4cbad2e59be277999defc3754905d0f15aa75702cdead3cfe669ff72f08811c52971613",
+    "u2net.pth": "16f8125e2fedd8c85db0e001ee15338b4aa2fda77bab8ba70c25e"
+    "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7",
+    "tracer_b7.pth": "c439c5c12d4d43d5f9be9ec61e68b2e54658a541bccac2577ef5a54fb252b6e8415d41f7e"
+    "c2487033d0c02b4dd08367958e4e62091318111c519f93e2632be7b",
+    "scene_classifier.pth": "6d8692510abde453b406a1fea557afdea62fd2a2a2677283a3ecc2"
+    "341a4895ee99ed65cedcb79b80775db14c3ffcfc0aad2caec1d85140678852039d2d4e76b4",
+    "yolov4_coco_with_classes.pth": "44b6ec2dd35dc3802bf8c512002f76e00e97bfbc86bc7af6de2fafce229a41b4ca"
+    "12c6f3d7589278c71cd4ddd62df80389b148c19b84fa03216905407a107fff",
+    "cascadepsp.pth": "3f895f5126d80d6f73186f045557ea7c8eab4dfa3d69a995815bb2c03d564573f36c474f04d7bf0022a27829f583a1a793b036adf801cb423e41a4831b830122",
+}
+"""
+Model -> checksum dictionary
+"""
+
+
+def sha512_checksum_calc(file: Path) -> str:
+    """
+    Calculates the SHA512 hash digest of a file on fs
+
+    Args:
+        file (Path): Path to the file
+
+    Returns:
+        SHA512 hash digest of a file.
+    """
+    dd = hashlib.sha512()
+    with file.open("rb") as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            dd.update(chunk)
+    return dd.hexdigest()
+
+
+class CachedDownloader:
+    """
+    Base class for model downloaders.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        return self.__class__.__name__
+
+    @property
+    @abstractmethod
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        """
+        Property MAY be overridden in subclasses.
+        Used if the subclass fails to download the model. So the preferred downloader SHOULD be placed higher in the hierarchy.
+        The less preferred downloader SHOULD be provided by this property.
+        """
+        pass
+
+    def download_model(self, file_name: str) -> Path:
+        """
+        Downloads model from the internet and saves it to the cache.
+
+        Behavior:
+            If model is already downloaded it will be loaded from the cache.
+
+            If model is already downloaded, but checksum is invalid, it will be downloaded again.
+
+            If model download failed, fallback downloader will be used.
+        """
+        try:
+            return self.download_model_base(file_name)
+        except BaseException as e:
+            if self.fallback_downloader is not None:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" Trying to download from {self.fallback_downloader.name} downloader."
+                )
+                return self.fallback_downloader.download_model(file_name)
+            else:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" No fallback downloader available."
+                )
+                raise e
+
+    @abstractmethod
+    def download_model_base(self, model_name: str) -> Path:
+        """
+        Download model from any source if not cached.
+        Returns:
+            pathlib.Path: Path to the downloaded model.
+        """
+
+    def __call__(self, model_name: str):
+        return self.download_model(model_name)
+
+
+class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
+    """
+    Downloader for models from HuggingFace Hub.
+    Private models are not supported.
+    """
+
+    def __init__(
+        self,
+        name: str = "Huggingface.co",
+        base_url: str = "https://huggingface.co",
+        fb_downloader: Optional["CachedDownloader"] = None,
+    ):
+        self.cache_dir = checkpoints_dir
+        """SHOULD be same for all instances to prevent downloading same model multiple times
+        Points to ~/.cache/carvekit/checkpoints"""
+        self.base_url = base_url
+        """MUST be a base URL (protocol and domain name) of huggingface.co or of another source compatible with its model download API"""
+        self._name = name
+        self._fallback_downloader = fb_downloader
+
+    @property
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        return self._fallback_downloader
+
+    @property
+    def name(self):
+        return self._name
+
+    def check_for_existence(self, model_name: str) -> Optional[Path]:
+        """
+        Checks if model is already downloaded and cached. Verifies file integrity by checksum.
+        Returns:
+            Optional[pathlib.Path]: Path to the cached model if cached.
+        """
+        if model_name not in MODELS_URLS.keys():
+            raise FileNotFoundError("Unknown model!")
+        path = (
+            self.cache_dir
+            / MODELS_URLS[model_name]["repository"].split("/")[1]
+            / model_name
+        )
+
+        if not path.exists():
+            return None
+
+        if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
+            warnings.warn(
+                f"Invalid checksum for model {path.name}. Downloading correct model!"
+            )
+            os.remove(path)
+            return None
+        return path
+
+    def download_model_base(self, model_name: str) -> Path:
+        cached_path = self.check_for_existence(model_name)
+        if cached_path is not None:
+            return cached_path
+        else:
+            cached_path = (
+                self.cache_dir
+                / MODELS_URLS[model_name]["repository"].split("/")[1]
+                / model_name
+            )
+            cached_path.parent.mkdir(parents=True, exist_ok=True)
+            url = MODELS_URLS[model_name]
+            hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"
+
+            try:
+                r = requests.get(hugging_face_url, stream=True, timeout=10)
+                if r.status_code < 400:
+                    with open(cached_path, "wb") as f:
+                        r.raw.decode_content = True
+                        for chunk in tqdm.tqdm(
+                            r,
+                            desc="Downloading " + cached_path.name + " model",
+                            colour="blue",
+                        ):
+                            f.write(chunk)
+                else:
+                    if r.status_code == 404:
+                        raise FileNotFoundError(f"Model {model_name} not found!")
+                    else:
+                        raise ConnectionError(
+                            f"Error {r.status_code} while downloading model {model_name}!"
+                        )
+            except BaseException as e:
+                if cached_path.exists():
+                    os.remove(cached_path)
+                raise ConnectionError(
+                    f"Exception caught when downloading model! "
+                    f"Model name: {cached_path.name}. Exception: {str(e)}."
+                )
+            return cached_path
+
+
+fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader()
+downloader: CachedDownloader = HuggingFaceCompatibleDownloader(
+    base_url="https://cdn.carve.photos",
+    fb_downloader=fallback_downloader,
+    name="Carve CDN",
+)
+downloader._fallback_downloader = fallback_downloader
+
+
+
+
+
+

Global variables

+
+
var MODELS_CHECKSUMS
+
+

Model -> checksum dictionary

+
+
var MODELS_URLS
+
+

All data needed to build path relative to huggingface.co for model download

+
+
+
+
+

Functions

+
+
+def sha512_checksum_calc(file: pathlib.Path) -> str +
+
+

Calculates the SHA512 hash digest of a file on fs

+

Args

+
+
file : Path
+
Path to the file
+
+

Returns

+

SHA512 hash digest of a file.

+
+ +Expand source code + +
def sha512_checksum_calc(file: Path) -> str:
+    """
+    Calculates the SHA512 hash digest of a file on fs
+
+    Args:
+        file (Path): Path to the file
+
+    Returns:
+        SHA512 hash digest of a file.
+    """
+    dd = hashlib.sha512()
+    with file.open("rb") as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            dd.update(chunk)
+    return dd.hexdigest()
+
+
+
+
+
+

Classes

+
+
+class CachedDownloader +
+
+

Base class for model downloaders.

+
+ +Expand source code + +
class CachedDownloader:
+    """
+    Base class for model downloaders.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        return self.__class__.__name__
+
+    @property
+    @abstractmethod
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        """
+        Property MAY be overridden in subclasses.
+        Used if the subclass fails to download the model. So the preferred downloader SHOULD be placed higher in the hierarchy.
+        The less preferred downloader SHOULD be provided by this property.
+        """
+        pass
+
+    def download_model(self, file_name: str) -> Path:
+        """
+        Downloads model from the internet and saves it to the cache.
+
+        Behavior:
+            If model is already downloaded it will be loaded from the cache.
+
+            If model is already downloaded, but checksum is invalid, it will be downloaded again.
+
+            If model download failed, fallback downloader will be used.
+        """
+        try:
+            return self.download_model_base(file_name)
+        except BaseException as e:
+            if self.fallback_downloader is not None:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" Trying to download from {self.fallback_downloader.name} downloader."
+                )
+                return self.fallback_downloader.download_model(file_name)
+            else:
+                warnings.warn(
+                    f"Failed to download model from {self.name} downloader."
+                    f" No fallback downloader available."
+                )
+                raise e
+
+    @abstractmethod
+    def download_model_base(self, model_name: str) -> Path:
+        """
+        Download model from any source if not cached.
+        Returns:
+            pathlib.Path: Path to the downloaded model.
+        """
+
+    def __call__(self, model_name: str):
+        return self.download_model(model_name)
+
+

Subclasses

+ +

Instance variables

+
+
var fallback_downloader : Optional[CachedDownloader]
+
+

Property MAY be overridden in subclasses. +Used if the subclass fails to download the model. So the preferred downloader SHOULD be placed higher in the hierarchy. +The less preferred downloader SHOULD be provided by this property.

+
+ +Expand source code + +
@property
+@abstractmethod
+def fallback_downloader(self) -> Optional["CachedDownloader"]:
+    """
+    Property MAY be overridden in subclasses.
+    Used if the subclass fails to download the model. So the preferred downloader SHOULD be placed higher in the hierarchy.
+    The less preferred downloader SHOULD be provided by this property.
+    """
+    pass
+
+
+
var name : str
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def name(self) -> str:
+    return self.__class__.__name__
+
+
+
+

Methods

+
+
+def download_model(self, file_name: str) -> pathlib.Path +
+
+

Downloads model from the internet and saves it to the cache.

+

Behavior

+

If model is already downloaded it will be loaded from the cache.

+

If model is already downloaded, but checksum is invalid, it will be downloaded again.

+

If model download failed, fallback downloader will be used.

+
+ +Expand source code + +
def download_model(self, file_name: str) -> Path:
+    """
+    Downloads model from the internet and saves it to the cache.
+
+    Behavior:
+        If model is already downloaded it will be loaded from the cache.
+
+        If model is already downloaded, but checksum is invalid, it will be downloaded again.
+
+        If model download failed, fallback downloader will be used.
+    """
+    try:
+        return self.download_model_base(file_name)
+    except BaseException as e:
+        if self.fallback_downloader is not None:
+            warnings.warn(
+                f"Failed to download model from {self.name} downloader."
+                f" Trying to download from {self.fallback_downloader.name} downloader."
+            )
+            return self.fallback_downloader.download_model(file_name)
+        else:
+            warnings.warn(
+                f"Failed to download model from {self.name} downloader."
+                f" No fallback downloader available."
+            )
+            raise e
+
+
+
+def download_model_base(self, model_name: str) -> pathlib.Path +
+
+

Download model from any source if not cached.

+

Returns

+
+
pathlib.Path
+
Path to the downloaded model.
+
+
+ +Expand source code + +
@abstractmethod
+def download_model_base(self, model_name: str) -> Path:
+    """
+    Download model from any source if not cached.
+    Returns:
+        pathlib.Path: Path to the downloaded model.
+    """
+
+
+
+
+
+class HuggingFaceCompatibleDownloader +(name: str = 'Huggingface.co', base_url: str = 'https://huggingface.co', fb_downloader: Optional[ForwardRef('CachedDownloader')] = None) +
+
+

Downloader for models from HuggingFace Hub. +Private models are not supported.

+
+ +Expand source code + +
class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
+    """
+    Downloader for models from HuggingFace Hub.
+    Private models are not supported.
+    """
+
+    def __init__(
+        self,
+        name: str = "Huggingface.co",
+        base_url: str = "https://huggingface.co",
+        fb_downloader: Optional["CachedDownloader"] = None,
+    ):
+        self.cache_dir = checkpoints_dir
+        """SHOULD be same for all instances to prevent downloading same model multiple times
+        Points to ~/.cache/carvekit/checkpoints"""
+        self.base_url = base_url
+        """MUST be a base URL (protocol and domain name) of huggingface.co or of another source compatible with its model download API"""
+        self._name = name
+        self._fallback_downloader = fb_downloader
+
+    @property
+    def fallback_downloader(self) -> Optional["CachedDownloader"]:
+        return self._fallback_downloader
+
+    @property
+    def name(self):
+        return self._name
+
+    def check_for_existence(self, model_name: str) -> Optional[Path]:
+        """
+        Checks if model is already downloaded and cached. Verifies file integrity by checksum.
+        Returns:
+            Optional[pathlib.Path]: Path to the cached model if cached.
+        """
+        if model_name not in MODELS_URLS.keys():
+            raise FileNotFoundError("Unknown model!")
+        path = (
+            self.cache_dir
+            / MODELS_URLS[model_name]["repository"].split("/")[1]
+            / model_name
+        )
+
+        if not path.exists():
+            return None
+
+        if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
+            warnings.warn(
+                f"Invalid checksum for model {path.name}. Downloading correct model!"
+            )
+            os.remove(path)
+            return None
+        return path
+
+    def download_model_base(self, model_name: str) -> Path:
+        cached_path = self.check_for_existence(model_name)
+        if cached_path is not None:
+            return cached_path
+        else:
+            cached_path = (
+                self.cache_dir
+                / MODELS_URLS[model_name]["repository"].split("/")[1]
+                / model_name
+            )
+            cached_path.parent.mkdir(parents=True, exist_ok=True)
+            url = MODELS_URLS[model_name]
+            hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"
+
+            try:
+                r = requests.get(hugging_face_url, stream=True, timeout=10)
+                if r.status_code < 400:
+                    with open(cached_path, "wb") as f:
+                        r.raw.decode_content = True
+                        for chunk in tqdm.tqdm(
+                            r,
+                            desc="Downloading " + cached_path.name + " model",
+                            colour="blue",
+                        ):
+                            f.write(chunk)
+                else:
+                    if r.status_code == 404:
+                        raise FileNotFoundError(f"Model {model_name} not found!")
+                    else:
+                        raise ConnectionError(
+                            f"Error {r.status_code} while downloading model {model_name}!"
+                        )
+            except BaseException as e:
+                if cached_path.exists():
+                    os.remove(cached_path)
+                raise ConnectionError(
+                    f"Exception caught when downloading model! "
+                    f"Model name: {cached_path.name}. Exception: {str(e)}."
+                )
+            return cached_path
+
+

Ancestors

+ +

Instance variables

+
+
var base_url
+
+

MUST be a base URL (protocol and domain name) of huggingface.co or of another source compatible with its model download API

+
+
var cache_dir
+
+

SHOULD be same for all instances to prevent downloading same model multiple times +Points to ~/.cache/carvekit/checkpoints

+
+
var name
+
+
+
+ +Expand source code + +
@property
+def name(self):
+    return self._name
+
+
+
+

Methods

+
+
+def check_for_existence(self, model_name: str) -> Optional[pathlib.Path] +
+
+

Checks if model is already downloaded and cached. Verifies file integrity by checksum.

+

Returns

+
+
Optional[pathlib.Path]
+
Path to the cached model if cached.
+
+
+ +Expand source code + +
def check_for_existence(self, model_name: str) -> Optional[Path]:
+    """
+    Checks if model is already downloaded and cached. Verifies file integrity by checksum.
+    Returns:
+        Optional[pathlib.Path]: Path to the cached model if cached.
+    """
+    if model_name not in MODELS_URLS.keys():
+        raise FileNotFoundError("Unknown model!")
+    path = (
+        self.cache_dir
+        / MODELS_URLS[model_name]["repository"].split("/")[1]
+        / model_name
+    )
+
+    if not path.exists():
+        return None
+
+    if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
+        warnings.warn(
+            f"Invalid checksum for model {path.name}. Downloading correct model!"
+        )
+        os.remove(path)
+        return None
+    return path
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/high.html b/docs/api/high.html new file mode 100644 index 0000000..ee982a1 --- /dev/null +++ b/docs/api/high.html @@ -0,0 +1,377 @@ + + + + + + +high API documentation + + + + + + + + + + + +
+
+
+

Module high

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import warnings
+
+from carvekit.api.interface import Interface
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.ml.wrap.cascadepsp import CascadePSP
+from carvekit.ml.wrap.scene_classifier import SceneClassifier
+from carvekit.pipelines.preprocessing import AutoScene
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.pipelines.postprocessing import CasMattingMethod
+from carvekit.trimap.generator import TrimapGenerator
+
+
+class HiInterface(Interface):
+    def __init__(
+        self,
+        object_type: str = "auto",
+        batch_size_pre=5,
+        batch_size_seg=2,
+        batch_size_matting=1,
+        batch_size_refine=1,
+        device="cpu",
+        seg_mask_size=640,
+        matting_mask_size=2048,
+        refine_mask_size=900,
+        trimap_prob_threshold=231,
+        trimap_dilation=30,
+        trimap_erosion_iters=5,
+        fp16=False,
+    ):
+        """
+        Initializes High Level interface.
+
+        Args:
+            object_type (str, default=auto): Interest object type. Can be "auto", "object" or "hairs-like".
+            matting_mask_size (int, default=2048):  The size of the input image for the matting neural network.
+            seg_mask_size (int, default=640): The size of the input image for the segmentation neural network.
+            batch_size_pre (int, default=5): Number of images processed per one preprocessing method call.
+            batch_size_seg (int, default=2): Number of images processed per one segmentation neural network call.
+            batch_size_matting (int, default=1): Number of images processed per one matting neural network call.
+            device (Literal[cpu, cuda], default=cpu): Processing device
+            fp16 (bool, default=False): Use half precision. Reduce memory usage and increase speed.
+            .. CAUTION:: ⚠️ **Experimental support**
+            trimap_prob_threshold (int, default=231): Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+            trimap_dilation (int, default=30): The size of the offset radius from the object mask in pixels when forming an unknown area
+            trimap_erosion_iters (int, default=5): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+            refine_mask_size (int, default=900): The size of the input image for the refinement neural network.
+            batch_size_refine (int, default=1): Number of images processed per one refinement neural network call.
+
+
+        .. NOTE::
+            1. Changing `seg_mask_size` may cause an `out-of-memory` error if the value is too large, and it may also
+            result in reduced precision. I do not recommend changing this value. You can change `matting_mask_size` in
+            the range `1024` to `4096` to improve object edge refining quality, but this greatly increases RAM and
+            video memory consumption. You can also raise the batch sizes to accelerate background removal, but this
+            likewise increases video memory consumption if the values are too large.
+            2. Changing `trimap_prob_threshold`, `trimap_dilation`, `trimap_erosion_iters` may improve object edge
+            refining quality.
+        """
+        preprocess_pipeline = None
+
+        if object_type == "object":
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "hairs-like":
+            self._segnet = U2NET(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "auto":
+            # Using Tracer by default,
+            # but it will dynamically switch to other if needed
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+            self._scene_classifier = SceneClassifier(
+                device=device, fp16=fp16, batch_size=batch_size_pre
+            )
+            preprocess_pipeline = AutoScene(scene_classifier=self._scene_classifier)
+
+        else:
+            warnings.warn(
+                f"Unknown object type: {object_type}. Using default object type: object"
+            )
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+
+        self._cascade_psp = CascadePSP(
+            device=device,
+            batch_size=batch_size_refine,
+            input_tensor_size=refine_mask_size,
+            fp16=fp16,
+        )
+        self._fba = FBAMatting(
+            batch_size=batch_size_matting,
+            device=device,
+            input_tensor_size=matting_mask_size,
+            fp16=fp16,
+        )
+        self._trimap_generator = TrimapGenerator(
+            prob_threshold=trimap_prob_threshold,
+            kernel_size=trimap_dilation,
+            erosion_iters=trimap_erosion_iters,
+        )
+        super(HiInterface, self).__init__(
+            pre_pipe=preprocess_pipeline,
+            seg_pipe=self._segnet,
+            post_pipe=CasMattingMethod(
+                refining_module=self._cascade_psp,
+                matting_module=self._fba,
+                trimap_generator=self._trimap_generator,
+                device=device,
+            ),
+            device=device,
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class HiInterface +(object_type: str = 'auto', batch_size_pre=5, batch_size_seg=2, batch_size_matting=1, batch_size_refine=1, device='cpu', seg_mask_size=640, matting_mask_size=2048, refine_mask_size=900, trimap_prob_threshold=231, trimap_dilation=30, trimap_erosion_iters=5, fp16=False) +
+
+

Initializes High Level interface.

+

Args

+
+
object_type : str, default=auto
+
Interest object type. Can be "auto", "object" or "hairs-like".
+
matting_mask_size : int, default=2048
+
The size of the input image for the matting neural network.
+
seg_mask_size : int, default=640
+
The size of the input image for the segmentation neural network.
+
batch_size_pre (int, default=5): Number of images processed per one preprocessing method call.
+
batch_size_seg : int, default=2
+
Number of images processed per one segmentation neural network call.
+
batch_size_matting : int, default=1
+
Number of images processed per one matting neural network call.
+
device : Literal[cpu, cuda], default=cpu
+
Processing device
+
fp16 : bool, default=False
+
Use half precision. Reduce memory usage and increase speed.
+
+
+

Caution: ⚠️ Experimental support

+
+
+
trimap_prob_threshold : int, default=231
+
Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+
trimap_dilation : int, default=30
+
The size of the offset radius from the object mask in pixels when forming an unknown area
+
trimap_erosion_iters : int, default=5
+
The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+
refine_mask_size : int, default=900
+
The size of the input image for the refinement neural network.
+
batch_size_refine : int, default=1
+
Number of images processed per one refinement neural network call.
+
+
+

Note

+
    +
  1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also +result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in +the range 1024 to 4096 to improve object edge refining quality, but this greatly increases RAM and +video memory consumption. You can also raise the batch sizes to accelerate background removal, but this +likewise increases video memory consumption if the values are too large.
  2. +
  3. Changing trimap_prob_threshold, trimap_dilation, trimap_erosion_iters may improve object edge +refining quality.
  4. +
+
+
+ +Expand source code + +
class HiInterface(Interface):
+    def __init__(
+        self,
+        object_type: str = "auto",
+        batch_size_pre=5,
+        batch_size_seg=2,
+        batch_size_matting=1,
+        batch_size_refine=1,
+        device="cpu",
+        seg_mask_size=640,
+        matting_mask_size=2048,
+        refine_mask_size=900,
+        trimap_prob_threshold=231,
+        trimap_dilation=30,
+        trimap_erosion_iters=5,
+        fp16=False,
+    ):
+        """
+        Initializes High Level interface.
+
+        Args:
+            object_type (str, default=auto): Interest object type. Can be "object", "hairs-like" or "auto".
+            matting_mask_size (int, default=2048):  The size of the input image for the matting neural network.
+            seg_mask_size (int, default=640): The size of the input image for the segmentation neural network.
+            batch_size_pre (int, default=5): Number of images processed per one preprocessing method call.
+            batch_size_seg (int, default=2): Number of images processed per one segmentation neural network call.
+            batch_size_matting (int, default=1): Number of images processed per one matting neural network call.
+            device (Literal[cpu, cuda], default=cpu): Processing device
+            fp16 (bool, default=False): Use half precision. Reduce memory usage and increase speed.
+            .. CAUTION:: ⚠️ **Experimental support**
+            trimap_prob_threshold (int, default=231): Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+            trimap_dilation (int, default=30): The size of the offset radius from the object mask in pixels when forming an unknown area
+            trimap_erosion_iters (int, default=5): The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+            refine_mask_size (int, default=900): The size of the input image for the refinement neural network.
+            batch_size_refine (int, default=1): Number of images processed per one refinement neural network call.
+
+
+        .. NOTE::
+            1. Changing seg_mask_size may cause an `out-of-memory` error if the value is too large, and it may also
+            result in reduced precision. I do not recommend changing this value. You can change `matting_mask_size` in
+            range from `(1024 to 4096)` to improve object edge refining quality, but it will cause extra large RAM and
+            video memory consume. Also, you can change batch size to accelerate background removal, but it also causes
+            extra large video memory consume, if value is too big.
+            2. Changing `trimap_prob_threshold`, `trimap_dilation`, `trimap_erosion_iters` may improve object edge
+            refining quality.
+        """
+        preprocess_pipeline = None
+
+        if object_type == "object":
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "hairs-like":
+            self._segnet = U2NET(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "auto":
+            # Using Tracer by default,
+            # but it will dynamically switch to other if needed
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+            self._scene_classifier = SceneClassifier(
+                device=device, fp16=fp16, batch_size=batch_size_pre
+            )
+            preprocess_pipeline = AutoScene(scene_classifier=self._scene_classifier)
+
+        else:
+            warnings.warn(
+                f"Unknown object type: {object_type}. Using default object type: object"
+            )
+            self._segnet = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+
+        self._cascade_psp = CascadePSP(
+            device=device,
+            batch_size=batch_size_refine,
+            input_tensor_size=refine_mask_size,
+            fp16=fp16,
+        )
+        self._fba = FBAMatting(
+            batch_size=batch_size_matting,
+            device=device,
+            input_tensor_size=matting_mask_size,
+            fp16=fp16,
+        )
+        self._trimap_generator = TrimapGenerator(
+            prob_threshold=trimap_prob_threshold,
+            kernel_size=trimap_dilation,
+            erosion_iters=trimap_erosion_iters,
+        )
+        super(HiInterface, self).__init__(
+            pre_pipe=preprocess_pipeline,
+            seg_pipe=self._segnet,
+            post_pipe=CasMattingMethod(
+                refining_module=self._cascade_psp,
+                matting_module=self._fba,
+                trimap_generator=self._trimap_generator,
+                device=device,
+            ),
+            device=device,
+        )
+
+

Ancestors

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/image_utils.html b/docs/api/image_utils.html new file mode 100644 index 0000000..c384714 --- /dev/null +++ b/docs/api/image_utils.html @@ -0,0 +1,510 @@ + + + + + + +image_utils API documentation + + + + + + + + + + + +
+
+
+

Module image_utils

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+    Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+    Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+    License: Apache License 2.0
+"""
+
+import pathlib
+from typing import Union, Any, Tuple
+
+import PIL.Image
+import numpy as np
+import torch
+
+ALLOWED_SUFFIXES = [".jpg", ".jpeg", ".bmp", ".png", ".webp"]
+
+
+def to_tensor(x: Any) -> torch.Tensor:
+    """
+    Returns a PIL.Image.Image as torch tensor without swap tensor dims.
+
+    Args:
+        x (PIL.Image.Image): image
+
+    Returns:
+        torch.Tensor: image as torch tensor
+    """
+    return torch.tensor(np.array(x, copy=True))
+
+
+def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image:
+    """Returns a `PIL.Image.Image` class by string path or `pathlib.Path` or `PIL.Image.Image` instance
+
+    Args:
+        file (Union[str, pathlib.Path, PIL.Image.Image]): File path or `PIL.Image.Image` instance
+
+    Returns:
+        PIL.Image.Image: image instance loaded from `file` location
+
+    Raises:
+        ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image
+
+    """
+    if isinstance(file, str) and is_image_valid(pathlib.Path(file)):
+        return PIL.Image.open(file)
+    elif isinstance(file, PIL.Image.Image):
+        return file
+    elif isinstance(file, pathlib.Path) and is_image_valid(file):
+        return PIL.Image.open(str(file))
+    else:
+        raise ValueError("Unknown input file type")
+
+
+def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image:
+    """Performs image conversion to correct color mode
+
+    Args:
+        image (PIL.Image.Image): `PIL.Image.Image` instance
+        mode (str, default=RGB): Color mode to convert
+
+    Returns:
+        PIL.Image.Image: converted image
+
+    Raises:
+        ValueError: If image hasn't convertable color mode, or it is too small
+    """
+    if is_image_valid(image):
+        return image.convert(mode)
+
+
+def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool:
+    """This function performs image validation.
+
+    Args:
+        image (Union[pathlib.Path, PIL.Image.Image]): Path to the image or `PIL.Image.Image` instance being checked.
+
+    Returns:
+        bool: Always True when the image is valid; invalid input raises ValueError instead of returning False.
+
+    Raises:
+        ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small
+
+    """
+    if isinstance(image, pathlib.Path):
+        if not image.exists():
+            raise ValueError("File is not exists")
+        elif image.is_dir():
+            raise ValueError("File is a directory")
+        elif image.suffix.lower() not in ALLOWED_SUFFIXES:
+            raise ValueError(
+                f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}"
+            )
+    elif isinstance(image, PIL.Image.Image):
+        if not (image.size[0] > 32 and image.size[1] > 32):
+            raise ValueError("Image should be bigger then (32x32) pixels.")
+        elif image.mode not in [
+            "RGB",
+            "RGBA",
+            "L",
+        ]:
+            raise ValueError("Wrong image color mode.")
+    else:
+        raise ValueError("Unknown input file type")
+    return True
+
+
+def transparency_paste(
+    bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)
+) -> PIL.Image.Image:
+    """
+    Inserts an image into another image while maintaining transparency.
+
+    Args:
+        bg_img (PIL.Image.Image): background image
+        fg_img (PIL.Image.Image): foreground image
+        box (tuple[int, int]): place to paste
+
+    Returns:
+        PIL.Image.Image: Background image with pasted foreground image at point or in the specified box
+    """
+    fg_img_trans = PIL.Image.new("RGBA", bg_img.size)
+    fg_img_trans.paste(fg_img, box, mask=fg_img)
+    new_img = PIL.Image.alpha_composite(bg_img, fg_img_trans)
+    return new_img
+
+
+def add_margin(
+    pil_img: PIL.Image.Image,
+    top: int,
+    right: int,
+    bottom: int,
+    left: int,
+    color: Tuple[int, int, int, int],
+) -> PIL.Image.Image:
+    """
+    Adds margin to the image.
+
+    Args:
+        pil_img (PIL.Image.Image): Image that needed to add margin.
+        top (int): pixels count at top side
+        right (int): pixels count at right side
+        bottom (int): pixels count at bottom side
+        left (int): pixels count at left side
+        color (Tuple[int, int, int, int]): color of margin
+
+    Returns:
+        PIL.Image.Image: Image with margin.
+    """
+    width, height = pil_img.size
+    new_width = width + right + left
+    new_height = height + top + bottom
+    # noinspection PyTypeChecker
+    result = PIL.Image.new(pil_img.mode, (new_width, new_height), color)
+    result.paste(pil_img, (left, top))
+    return result
+
+
+
+
+
+
+
+

Functions

+
+
+def add_margin(pil_img:Β PIL.Image.Image, top:Β int, right:Β int, bottom:Β int, left:Β int, color:Β Tuple[int,Β int,Β int,Β int]) ‑>Β PIL.Image.Image +
+
+

Adds margin to the image.

+

Args

+
+
pil_img : PIL.Image.Image
+
Image that needed to add margin.
+
top : int
+
pixels count at top side
+
right : int
+
pixels count at right side
+
bottom : int
+
pixels count at bottom side
+
left : int
+
pixels count at left side
+
color : Tuple[int, int, int, int]
+
color of margin
+
+

Returns

+
+
PIL.Image.Image
+
Image with margin.
+
+
+ +Expand source code + +
def add_margin(
+    pil_img: PIL.Image.Image,
+    top: int,
+    right: int,
+    bottom: int,
+    left: int,
+    color: Tuple[int, int, int, int],
+) -> PIL.Image.Image:
+    """
+    Adds margin to the image.
+
+    Args:
+        pil_img (PIL.Image.Image): Image that needed to add margin.
+        top (int): pixels count at top side
+        right (int): pixels count at right side
+        bottom (int): pixels count at bottom side
+        left (int): pixels count at left side
+        color (Tuple[int, int, int, int]): color of margin
+
+    Returns:
+        PIL.Image.Image: Image with margin.
+    """
+    width, height = pil_img.size
+    new_width = width + right + left
+    new_height = height + top + bottom
+    # noinspection PyTypeChecker
+    result = PIL.Image.new(pil_img.mode, (new_width, new_height), color)
+    result.paste(pil_img, (left, top))
+    return result
+
+
+
+def convert_image(image:Β PIL.Image.Image, mode='RGB') ‑>Β PIL.Image.Image +
+
+

Performs image conversion to correct color mode

+

Args

+
+
image : PIL.Image.Image
+
PIL.Image.Image instance
+
mode : str, default=RGB
+
Color mode to convert
+
+

Returns

+
+
PIL.Image.Image
+
converted image
+
+

Raises

+
+
ValueError
+
If image hasn't convertable color mode, or it is too small
+
+
+ +Expand source code + +
def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image:
+    """Performs image conversion to correct color mode
+
+    Args:
+        image (PIL.Image.Image): `PIL.Image.Image` instance
+        mode (str, default=RGB): Color mode to convert
+
+    Returns:
+        PIL.Image.Image: converted image
+
+    Raises:
+        ValueError: If image hasn't convertable color mode, or it is too small
+    """
+    if is_image_valid(image):
+        return image.convert(mode)
+
+
+
+def is_image_valid(image:Β Union[pathlib.Path,Β PIL.Image.Image]) ‑>Β bool +
+
+

This function performs image validation.

+

Args

+
+
image : Union[pathlib.Path, PIL.Image.Image]
+
Path to the image or PIL.Image.Image instance being checked.
+
+

Returns

+
+
bool
+
Always True when the image is valid; invalid input raises ValueError instead of returning False.
+
+

Raises

+
+
ValueError
+
If file not a valid image path or image hasn't convertable color mode, or it is too small
+
+
+ +Expand source code + +
def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool:
+    """This function performs image validation.
+
+    Args:
+        image (Union[pathlib.Path, PIL.Image.Image]): Path to the image or `PIL.Image.Image` instance being checked.
+
+    Returns:
+        bool: Always True when the image is valid; invalid input raises ValueError instead of returning False.
+
+    Raises:
+        ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small
+
+    """
+    if isinstance(image, pathlib.Path):
+        if not image.exists():
+            raise ValueError("File is not exists")
+        elif image.is_dir():
+            raise ValueError("File is a directory")
+        elif image.suffix.lower() not in ALLOWED_SUFFIXES:
+            raise ValueError(
+                f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}"
+            )
+    elif isinstance(image, PIL.Image.Image):
+        if not (image.size[0] > 32 and image.size[1] > 32):
+            raise ValueError("Image should be bigger then (32x32) pixels.")
+        elif image.mode not in [
+            "RGB",
+            "RGBA",
+            "L",
+        ]:
+            raise ValueError("Wrong image color mode.")
+    else:
+        raise ValueError("Unknown input file type")
+    return True
+
+
+
+def load_image(file:Β Union[str,Β pathlib.Path,Β PIL.Image.Image]) ‑>Β PIL.Image.Image +
+
+

Returns a PIL.Image.Image class by string path or pathlib.Path or PIL.Image.Image instance

+

Args

+
+
file : Union[str, pathlib.Path, PIL.Image.Image]
+
File path or PIL.Image.Image instance
+
+

Returns

+
+
PIL.Image.Image
+
image instance loaded from file location
+
+

Raises

+
+
ValueError
+
If file not exists or file is directory or file isn't an image or file is not correct PIL Image
+
+
+ +Expand source code + +
def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image:
+    """Returns a `PIL.Image.Image` class by string path or `pathlib.Path` or `PIL.Image.Image` instance
+
+    Args:
+        file (Union[str, pathlib.Path, PIL.Image.Image]): File path or `PIL.Image.Image` instance
+
+    Returns:
+        PIL.Image.Image: image instance loaded from `file` location
+
+    Raises:
+        ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image
+
+    """
+    if isinstance(file, str) and is_image_valid(pathlib.Path(file)):
+        return PIL.Image.open(file)
+    elif isinstance(file, PIL.Image.Image):
+        return file
+    elif isinstance(file, pathlib.Path) and is_image_valid(file):
+        return PIL.Image.open(str(file))
+    else:
+        raise ValueError("Unknown input file type")
+
+
+
+def to_tensor(x:Β Any) ‑>Β torch.Tensor +
+
+

Returns a PIL.Image.Image as torch tensor without swap tensor dims.

+

Args

+
+
x : PIL.Image.Image
+
image
+
+

Returns

+
+
torch.Tensor
+
image as torch tensor
+
+
+ +Expand source code + +
def to_tensor(x: Any) -> torch.Tensor:
+    """
+    Returns a PIL.Image.Image as torch tensor without swap tensor dims.
+
+    Args:
+        x (PIL.Image.Image): image
+
+    Returns:
+        torch.Tensor: image as torch tensor
+    """
+    return torch.tensor(np.array(x, copy=True))
+
+
+
+def transparency_paste(bg_img:Β PIL.Image.Image, fg_img:Β PIL.Image.Image, box=(0, 0)) ‑>Β PIL.Image.Image +
+
+

Inserts an image into another image while maintaining transparency.

+

Args

+
+
bg_img : PIL.Image.Image
+
background image
+
fg_img : PIL.Image.Image
+
foreground image
+
box : tuple[int, int]
+
place to paste
+
+

Returns

+
+
PIL.Image.Image
+
Background image with pasted foreground image at point or in the specified box
+
+
+ +Expand source code + +
def transparency_paste(
+    bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)
+) -> PIL.Image.Image:
+    """
+    Inserts an image into another image while maintaining transparency.
+
+    Args:
+        bg_img (PIL.Image.Image): background image
+        fg_img (PIL.Image.Image): foreground image
+        box (tuple[int, int]): place to paste
+
+    Returns:
+        PIL.Image.Image: Background image with pasted foreground image at point or in the specified box
+    """
+    fg_img_trans = PIL.Image.new("RGBA", bg_img.size)
+    fg_img_trans.paste(fg_img, box, mask=fg_img)
+    new_img = PIL.Image.alpha_composite(bg_img, fg_img_trans)
+    return new_img
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/interface.html b/docs/api/interface.html new file mode 100644 index 0000000..39c5799 --- /dev/null +++ b/docs/api/interface.html @@ -0,0 +1,234 @@ + + + + + + +interface API documentation + + + + + + + + + + + +
+
+
+

Module interface

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+from typing import Union, List, Optional
+
+from PIL import Image
+
+from carvekit.ml.wrap.basnet import BASNET
+from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene
+from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod
+from carvekit.utils.image_utils import load_image
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+
+
+class Interface:
+    def __init__(
+        self,
+        seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]],
+        pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None,
+        post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None,
+        device="cpu",
+    ):
+        """
+        Initializes an object for interacting with pipelines and other components of the CarveKit framework.
+
+        Args:
+            pre_pipe (Optional[Union[PreprocessingStub, AutoScene]]): Initialized pre-processing pipeline object
+            seg_pipe (Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]]): Initialized segmentation network object
+            post_pipe (Optional[Union[MattingMethod, CasMattingMethod]]): Initialized postprocessing pipeline object
+            device (Literal[cpu, cuda], default=cpu): The processing device that will be used to apply the masks to the images.
+        """
+        self.device = device
+        self.preprocessing_pipeline = pre_pipe
+        self.segmentation_pipeline = seg_pipe
+        self.postprocessing_pipeline = post_pipe
+
+    def __call__(
+        self, images: List[Union[str, Path, Image.Image]]
+    ) -> List[Image.Image]:
+        """
+        Removes the background from the specified images.
+
+        Args:
+            images: list of input images
+
+        Returns:
+            List of images without background as PIL.Image.Image instances
+        """
+        if self.segmentation_pipeline is None:
+            raise ValueError(
+                "Segmentation pipeline is not initialized."
+                "Override the class or pass the pipeline to the constructor."
+            )
+        images = thread_pool_processing(load_image, images)
+        if self.preprocessing_pipeline is not None:
+            masks: List[Image.Image] = self.preprocessing_pipeline(
+                interface=self, images=images
+            )
+        else:
+            masks: List[Image.Image] = self.segmentation_pipeline(images=images)
+
+        if self.postprocessing_pipeline is not None:
+            images: List[Image.Image] = self.postprocessing_pipeline(
+                images=images, masks=masks
+            )
+        else:
+            images = list(
+                map(
+                    lambda x: apply_mask(
+                        image=images[x], mask=masks[x], device=self.device
+                    ),
+                    range(len(images)),
+                )
+            )
+        return images
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class Interface +(seg_pipe:Β Union[U2NET,Β BASNET,Β DeepLabV3,Β TracerUniversalB7,Β ForwardRef(None)], pre_pipe:Β Union[PreprocessingStub,Β AutoScene,Β ForwardRef(None)]Β =Β None, post_pipe:Β Union[MattingMethod,Β CasMattingMethod,Β ForwardRef(None)]Β =Β None, device='cpu') +
+
+

Initializes an object for interacting with pipelines and other components of the CarveKit framework.

+

Args

+
+
pre_pipe : Optional[Union[PreprocessingStub, AutoScene]]
+
Initialized pre-processing pipeline object
+
seg_pipe : Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]]
+
Initialized segmentation network object
+
post_pipe : Optional[Union[MattingMethod, CasMattingMethod]]
+
Initialized postprocessing pipeline object
+
device : Literal[cpu, cuda], default=cpu
+
The processing device that will be used to apply the masks to the images.
+
+
+ +Expand source code + +
class Interface:
+    def __init__(
+        self,
+        seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]],
+        pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None,
+        post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None,
+        device="cpu",
+    ):
+        """
+        Initializes an object for interacting with pipelines and other components of the CarveKit framework.
+
+        Args:
+            pre_pipe (Optional[Union[PreprocessingStub, AutoScene]]): Initialized pre-processing pipeline object
+            seg_pipe (Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]]): Initialized segmentation network object
+            post_pipe (Optional[Union[MattingMethod, CasMattingMethod]]): Initialized postprocessing pipeline object
+            device (Literal[cpu, cuda], default=cpu): The processing device that will be used to apply the masks to the images.
+        """
+        self.device = device
+        self.preprocessing_pipeline = pre_pipe
+        self.segmentation_pipeline = seg_pipe
+        self.postprocessing_pipeline = post_pipe
+
+    def __call__(
+        self, images: List[Union[str, Path, Image.Image]]
+    ) -> List[Image.Image]:
+        """
+        Removes the background from the specified images.
+
+        Args:
+            images: list of input images
+
+        Returns:
+            List of images without background as PIL.Image.Image instances
+        """
+        if self.segmentation_pipeline is None:
+            raise ValueError(
+                "Segmentation pipeline is not initialized."
+                "Override the class or pass the pipeline to the constructor."
+            )
+        images = thread_pool_processing(load_image, images)
+        if self.preprocessing_pipeline is not None:
+            masks: List[Image.Image] = self.preprocessing_pipeline(
+                interface=self, images=images
+            )
+        else:
+            masks: List[Image.Image] = self.segmentation_pipeline(images=images)
+
+        if self.postprocessing_pipeline is not None:
+            images: List[Image.Image] = self.postprocessing_pipeline(
+                images=images, masks=masks
+            )
+        else:
+            images = list(
+                map(
+                    lambda x: apply_mask(
+                        image=images[x], mask=masks[x], device=self.device
+                    ),
+                    range(len(images)),
+                )
+            )
+        return images
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/matting.html b/docs/api/matting.html new file mode 100644 index 0000000..77ca37f --- /dev/null +++ b/docs/api/matting.html @@ -0,0 +1,221 @@ + + + + + + +matting API documentation + + + + + + + + + + + +
+
+
+

Module matting

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from typing import Union, List
+from PIL import Image
+from pathlib import Path
+from carvekit.trimap.cv_gen import CV2TrimapGenerator
+from carvekit.trimap.generator import TrimapGenerator
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+from carvekit.utils.image_utils import load_image, convert_image
+
+__all__ = ["MattingMethod"]
+
+
+class MattingMethod:
+    """
+    Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap.
+    Neural network for matting performs accurate object edge detection by using a special map called trimap,
+    with unknown area that we scan for boundary, already known general object area and the background."""
+
+    def __init__(
+        self,
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes Matting Method class.
+
+        Args:
+        - `matting_module`: Initialized matting neural network class
+        - `trimap_generator`: Initialized trimap generator class
+        - `device`: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+        - `images`: list of images
+        - `masks`: list of masks
+
+        Returns:
+        list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MattingMethod +(matting_module:Β FBAMatting, trimap_generator:Β Union[TrimapGenerator,Β CV2TrimapGenerator], device='cpu') +
+
+

Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap. +Neural network for matting performs accurate object edge detection by using a special map called trimap, +with unknown area that we scan for boundary, already known general object area and the background.

+

Initializes Matting Method class.

+

Args: +- matting_module: Initialized matting neural network class +- trimap_generator: Initialized trimap generator class +- device: Processing device used for applying mask to image

+
+ +Expand source code + +
class MattingMethod:
+    """
+    Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap.
+    Neural network for matting performs accurate object edge detection by using a special map called trimap,
+    with unknown area that we scan for boundary, already known general object area and the background."""
+
+    def __init__(
+        self,
+        matting_module: Union[FBAMatting],
+        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
+        device="cpu",
+    ):
+        """
+        Initializes Matting Method class.
+
+        Args:
+        - `matting_module`: Initialized matting neural network class
+        - `trimap_generator`: Initialized trimap generator class
+        - `device`: Processing device used for applying mask to image
+        """
+        self.device = device
+        self.matting_module = matting_module
+        self.trimap_generator = trimap_generator
+
+    def __call__(
+        self,
+        images: List[Union[str, Path, Image.Image]],
+        masks: List[Union[str, Path, Image.Image]],
+    ):
+        """
+        Passes data through apply_mask function
+
+        Args:
+        - `images`: list of images
+        - `masks`: list of masks
+
+        Returns:
+        list of images
+        """
+        if len(images) != len(masks):
+            raise ValueError("Images and Masks lists should have same length!")
+        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
+        masks = thread_pool_processing(
+            lambda x: convert_image(load_image(x), mode="L"), masks
+        )
+        trimaps = thread_pool_processing(
+            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
+            range(len(images)),
+        )
+        alpha = self.matting_module(images=images, trimaps=trimaps)
+        return list(
+            map(
+                lambda x: apply_mask(
+                    image=images[x], mask=alpha[x], device=self.device
+                ),
+                range(len(images)),
+            )
+        )
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/models_loc.html b/docs/api/models_loc.html new file mode 100644 index 0000000..34ecdec --- /dev/null +++ b/docs/api/models_loc.html @@ -0,0 +1,412 @@ + + + + + + +models_loc API documentation + + + + + + + + + + + +
+
+
+

Module models_loc

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import pathlib
+from carvekit.ml.files import checkpoints_dir
+from carvekit.utils.download_models import downloader
+
+
+def u2net_full_pretrained() -> pathlib.Path:
+    """Returns u2net pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("u2net.pth")
+
+
+def basnet_pretrained() -> pathlib.Path:
+    """Returns basnet pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("basnet.pth")
+
+
+def deeplab_pretrained() -> pathlib.Path:
+    """Returns deeplab pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("deeplab.pth")
+
+
+def fba_pretrained() -> pathlib.Path:
+    """Returns FBA matting pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("fba_matting.pth")
+
+
+def tracer_b7_pretrained() -> pathlib.Path:
+    """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("tracer_b7.pth")
+
+
+def scene_classifier_pretrained() -> pathlib.Path:
+    """Returns scene classifier pretrained model location
+    This model is used to classify scenes into 3 categories: hard, soft, digital
+
+    hard - scenes with hard edges, such as objects, buildings, etc.
+    soft - scenes with soft edges, such as portraits, hair, animals, etc.
+    digital - digital scenes, such as screenshots, graphics, etc.
+
+    more info: https://huggingface.co/Carve/scene_classifier
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("scene_classifier.pth")
+
+
+def yolov4_coco_pretrained() -> pathlib.Path:
+    """Returns yolov4 classifier pretrained model location
+    This model is used to classify objects in images.
+
+    Training dataset: COCO 2017
+    Training classes: 80
+
+    It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch)
+    We have only added coco classnames to the model.
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("yolov4_coco_with_classes.pth")
+
+
+def cascadepsp_pretrained() -> pathlib.Path:
+    """Returns cascade psp pretrained model location
+    This model is used to refine segmentation masks.
+
+    Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000
+    more info: https://huggingface.co/Carve/cascadepsp
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("cascadepsp.pth")
+
+
+def download_all():
+    u2net_full_pretrained()
+    fba_pretrained()
+    deeplab_pretrained()
+    basnet_pretrained()
+    tracer_b7_pretrained()
+    scene_classifier_pretrained()
+    yolov4_coco_pretrained()
+    cascadepsp_pretrained()
+
+
+
+
+
+
+
+

Functions

+
+
+def basnet_pretrained() ‑>Β pathlib.Path +
+
+

Returns basnet pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def basnet_pretrained() -> pathlib.Path:
+    """Returns basnet pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("basnet.pth")
+
+
+
+def cascadepsp_pretrained() ‑>Β pathlib.Path +
+
+

Returns cascade psp pretrained model location +This model is used to refine segmentation masks.

+

Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000 +more info: https://huggingface.co/Carve/cascadepsp

+

Returns

+

pathlib.Path to model location

+
+ +Expand source code + +
def cascadepsp_pretrained() -> pathlib.Path:
+    """Returns cascade psp pretrained model location
+    This model is used to refine segmentation masks.
+
+    Training dataset: MSRA-10K, DUT-OMRON, ECSSD and FSS-1000
+    more info: https://huggingface.co/Carve/cascadepsp
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("cascadepsp.pth")
+
+
+
+def deeplab_pretrained() ‑>Β pathlib.Path +
+
+

Returns deeplab pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def deeplab_pretrained() -> pathlib.Path:
+    """Returns deeplab pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("deeplab.pth")
+
+
+
+def download_all() +
+
+
+
+ +Expand source code + +
def download_all():
+    u2net_full_pretrained()
+    fba_pretrained()
+    deeplab_pretrained()
+    basnet_pretrained()
+    tracer_b7_pretrained()
+    scene_classifier_pretrained()
+    yolov4_coco_pretrained()
+    cascadepsp_pretrained()
+
+
+
+def fba_pretrained() ‑>Β pathlib.Path +
+
+

Returns FBA matting pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def fba_pretrained() -> pathlib.Path:
+    """Returns FBA matting pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("fba_matting.pth")
+
+
+
+def scene_classifier_pretrained() ‑>Β pathlib.Path +
+
+

Returns scene classifier pretrained model location +This model is used to classify scenes into 3 categories: hard, soft, digital

+

hard - scenes with hard edges, such as objects, buildings, etc. +soft - scenes with soft edges, such as portraits, hair, animals, etc. +digital - digital scenes, such as screenshots, graphics, etc.

+

more info: https://huggingface.co/Carve/scene_classifier

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def scene_classifier_pretrained() -> pathlib.Path:
+    """Returns scene classifier pretrained model location
+    This model is used to classify scenes into 3 categories: hard, soft, digital
+
+    hard - scenes with hard edges, such as objects, buildings, etc.
+    soft - scenes with soft edges, such as portraits, hair, animals, etc.
+    digital - digital scenes, such as screenshots, graphics, etc.
+
+    more info: https://huggingface.co/Carve/scene_classifier
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("scene_classifier.pth")
+
+
+
+def tracer_b7_pretrained() ‑>Β pathlib.Path +
+
+

Returns TRACER with EfficientNet v1 b7 encoder pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def tracer_b7_pretrained() -> pathlib.Path:
+    """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("tracer_b7.pth")
+
+
+
+def u2net_full_pretrained() ‑>Β pathlib.Path +
+
+

Returns u2net pretrained model location

+

Returns

+
+
pathlib.Path
+
model location
+
+
+ +Expand source code + +
def u2net_full_pretrained() -> pathlib.Path:
+    """Returns u2net pretrained model location
+
+    Returns:
+        pathlib.Path: model location
+    """
+    return downloader("u2net.pth")
+
+
+
+def yolov4_coco_pretrained() ‑>Β pathlib.Path +
+
+

Returns yolov4 classifier pretrained model location +This model is used to classify objects in images.

+

Training dataset: COCO 2017 +Training classes: 80

+

It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch) +We have only added coco classnames to the model.

+

Returns

+

pathlib.Path to model location

+
+ +Expand source code + +
def yolov4_coco_pretrained() -> pathlib.Path:
+    """Returns yolov4 classifier pretrained model location
+    This model is used to classify objects in images.
+
+    Training dataset: COCO 2017
+    Training classes: 80
+
+    It's a modified version of the original model from https://github.com/Tianxiaomo/pytorch-YOLOv4 (pytorch)
+    We have only added coco classnames to the model.
+
+    Returns:
+        pathlib.Path to model location
+    """
+    return downloader("yolov4_coco_with_classes.pth")
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/scene_classifier.html b/docs/api/scene_classifier.html new file mode 100644 index 0000000..6936110 --- /dev/null +++ b/docs/api/scene_classifier.html @@ -0,0 +1,454 @@ + + + + + + +scene_classifier API documentation + + + + + + + + + + + +
+
+
+

Module scene_classifier

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+
+import PIL.Image
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from typing import List, Union, Tuple
+from torch.autograd import Variable
+
+from carvekit.ml.files.models_loc import scene_classifier_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["SceneClassifier"]
+
+
+class SceneClassifier:
+    """
+    SceneClassifier model interface
+
+    Description:
+        Performs a primary analysis of the image in order to select the necessary method for removing the background.
+        The choice is made by classifying the scene type.
+
+        The output can be the following types:
+        - hard
+        - soft
+        - digital
+
+    """
+
+    def __init__(
+        self,
+        topk: int = 1,
+        device="cpu",
+        batch_size: int = 4,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the Scene Classifier.
+
+        Args:
+            topk: number of top classes to return
+            device: processing device
+            batch_size: the number of images that the neural network processes in one run
+            fp16: use fp16 precision
+
+        """
+        if model_path is None:
+            model_path = scene_classifier_pretrained()
+        self.topk = topk
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+        state_dict = torch.load(model_path, map_location=device)
+        self.model = state_dict["model"]
+        self.class_to_idx = state_dict["class_to_idx"]
+        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+        self.model.to(device)
+        self.model.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+
+        Returns:
+            Top-k class of scene type, probability of these classes
+
+        """
+        ps = F.softmax(data.float(), dim=0)
+        topk = ps.cpu().topk(self.topk)
+
+        probs, classes = (e.data.numpy().squeeze().tolist() for e in topk)
+        if isinstance(classes, int):
+            classes = [classes]
+            probs = [probs]
+        return list(map(lambda x: self.idx_to_class[x], classes)), probs
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> Tuple[List[str], List[float]]:
+        """
+        Passes input images through the neural network and returns class predictions.
+
+        Args:
+            images: input images
+
+        Returns:
+            Top-k class of scene type, probability of these classes for every passed image
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.model, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = Variable(batches).to(self.device)
+                    masks = self.model.forward(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks_cpu[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SceneClassifier +(topk:Β intΒ =Β 1, device='cpu', batch_size:Β intΒ =Β 4, fp16:Β boolΒ =Β False, model_path:Β Union[str,Β pathlib.Path]Β =Β None) +
+
+

SceneClassifier model interface

+

Description

+

Performs a primary analysis of the image in order to select the necessary method for removing the background. +The choice is made by classifying the scene type.

+

The output can be the following types: +- hard +- soft +- digital

+

Initialize the Scene Classifier.

+

Args

+
+
topk
+
number of top classes to return
+
device
+
processing device
+
batch_size
+
the number of images that the neural network processes in one run
+
fp16
+
use fp16 precision
+
+
+ +Expand source code + +
class SceneClassifier:
+    """
+    SceneClassifier model interface
+
+    Description:
+        Performs a primary analysis of the image in order to select the necessary method for removing the background.
+        The choice is made by classifying the scene type.
+
+        The output can be the following types:
+        - hard
+        - soft
+        - digital
+
+    """
+
+    def __init__(
+        self,
+        topk: int = 1,
+        device="cpu",
+        batch_size: int = 4,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the Scene Classifier.
+
+        Args:
+            topk: number of top classes to return
+            device: processing device
+            batch_size: the number of images that the neural network processes in one run
+            fp16: use fp16 precision
+
+        """
+        if model_path is None:
+            model_path = scene_classifier_pretrained()
+        self.topk = topk
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+        state_dict = torch.load(model_path, map_location=device)
+        self.model = state_dict["model"]
+        self.class_to_idx = state_dict["class_to_idx"]
+        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+        self.model.to(device)
+        self.model.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data: input image
+
+        Returns:
+            input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data: output data from neural network
+
+        Returns:
+            Top-k class of scene type, probability of these classes
+
+        """
+        ps = F.softmax(data.float(), dim=0)
+        topk = ps.cpu().topk(self.topk)
+
+        probs, classes = (e.data.numpy().squeeze().tolist() for e in topk)
+        if isinstance(classes, int):
+            classes = [classes]
+            probs = [probs]
+        return list(map(lambda x: self.idx_to_class[x], classes)), probs
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> Tuple[List[str], List[float]]:
+        """
+        Passes input images through the neural network and returns class predictions.
+
+        Args:
+            images: input images
+
+        Returns:
+            Top-k class of scene type, probability of these classes for every passed image
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self.model, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = Variable(batches).to(self.device)
+                    masks = self.model.forward(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(masks_cpu[x]),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+

Methods

+
+
+def data_postprocessing(self, data:Β torch.Tensor) ‑>Β Tuple[List[str],Β List[float]] +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data
+
output data from neural network
+
+

Returns

+

Top-k class of scene type, probability of these classes

+
+ +Expand source code + +
def data_postprocessing(self, data: torch.Tensor) -> Tuple[List[str], List[float]]:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data: output data from neural network
+
+    Returns:
+        Top-k class of scene type, probability of these classes
+
+    """
+    ps = F.softmax(data.float(), dim=0)
+    topk = ps.cpu().topk(self.topk)
+
+    probs, classes = (e.data.numpy().squeeze().tolist() for e in topk)
+    if isinstance(classes, int):
+        classes = [classes]
+        probs = [probs]
+    return list(map(lambda x: self.idx_to_class[x], classes)), probs
+
+
+
+def data_preprocessing(self, data:Β PIL.Image.Image) ‑>Β torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data
+
input image
+
+

Returns

+

input for neural network

+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data: input image
+
+    Returns:
+        input for neural network
+
+    """
+
+    return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/stub.html b/docs/api/stub.html new file mode 100644 index 0000000..8b15793 --- /dev/null +++ b/docs/api/stub.html @@ -0,0 +1,122 @@ + + + + + + +stub API documentation + + + + + + + + + + + +
+
+
+

Module stub

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+from typing import Union, List
+
+from PIL import Image
+
+__all__ = ["PreprocessingStub"]
+
+
+class PreprocessingStub:
+    """Stub for future preprocessing methods"""
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Passes data through `interface.segmentation_pipeline()` method
+
+        Args:
+        - `interface`: Interface instance
+        - `images`: list of images
+
+        Returns:
+            the result of passing data through segmentation_pipeline method of interface
+        """
+        return interface.segmentation_pipeline(images=images)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class PreprocessingStub +
+
+

Stub for future preprocessing methods

+
+ +Expand source code + +
class PreprocessingStub:
+    """Stub for future preprocessing methods"""
+
+    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
+        """
+        Passes data through `interface.segmentation_pipeline()` method
+
+        Args:
+        - `interface`: Interface instance
+        - `images`: list of images
+
+        Returns:
+            the result of passing data through segmentation_pipeline method of interface
+        """
+        return interface.segmentation_pipeline(images=images)
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/test_trimap.html b/docs/api/test_trimap.html new file mode 100644 index 0000000..339bf27 --- /dev/null +++ b/docs/api/test_trimap.html @@ -0,0 +1,172 @@ + + + + + + +test_trimap API documentation + + + + + + + + + + + +
+
+
+

Module test_trimap

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool

+

Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].

+

License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+
+License: Apache License 2.0
+"""
+import PIL.Image
+import pytest
+
+from carvekit.trimap.add_ops import prob_as_unknown_area
+
+
+def test_trimap_generator(trimap_instance, image_mask, image_pil):
+    te = trimap_instance()
+    assert isinstance(te(image_pil, image_mask), PIL.Image.Image)
+    assert isinstance(
+        te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("L", (512, 512))),
+        PIL.Image.Image,
+    )
+    assert isinstance(
+        te(
+            PIL.Image.new("RGB", (512, 512), color=(255, 255, 255)),
+            PIL.Image.new("L", (512, 512), color=255),
+        ),
+        PIL.Image.Image,
+    )
+    with pytest.raises(ValueError):
+        te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)))
+    with pytest.raises(ValueError):
+        te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)))
+
+
+def test_cv2_generator(cv2_trimap_instance, image_pil, image_mask):
+    cv2trimapgen = cv2_trimap_instance()
+    assert isinstance(cv2trimapgen(image_pil, image_mask), PIL.Image.Image)
+    with pytest.raises(ValueError):
+        cv2trimapgen(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)))
+    with pytest.raises(ValueError):
+        cv2trimapgen(PIL.Image.new("L", (256, 256)), PIL.Image.new("L", (512, 512)))
+
+
+def test_prob_as_unknown_area(image_pil, image_mask):
+    with pytest.raises(ValueError):
+        prob_as_unknown_area(image_pil, image_mask)
+
+
+
+
+
+
+
+

Functions

+
+
+def test_cv2_generator(cv2_trimap_instance, image_pil, image_mask) +
+
+
+
+ +Expand source code + +
def test_cv2_generator(cv2_trimap_instance, image_pil, image_mask):
+    cv2trimapgen = cv2_trimap_instance()
+    assert isinstance(cv2trimapgen(image_pil, image_mask), PIL.Image.Image)
+    with pytest.raises(ValueError):
+        cv2trimapgen(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)))
+    with pytest.raises(ValueError):
+        cv2trimapgen(PIL.Image.new("L", (256, 256)), PIL.Image.new("L", (512, 512)))
+
+
+
+def test_prob_as_unknown_area(image_pil, image_mask) +
+
+
+
+ +Expand source code + +
def test_prob_as_unknown_area(image_pil, image_mask):
+    with pytest.raises(ValueError):
+        prob_as_unknown_area(image_pil, image_mask)
+
+
+
+def test_trimap_generator(trimap_instance, image_mask, image_pil) +
+
+
+
+ +Expand source code + +
def test_trimap_generator(trimap_instance, image_mask, image_pil):
+    te = trimap_instance()
+    assert isinstance(te(image_pil, image_mask), PIL.Image.Image)
+    assert isinstance(
+        te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("L", (512, 512))),
+        PIL.Image.Image,
+    )
+    assert isinstance(
+        te(
+            PIL.Image.new("RGB", (512, 512), color=(255, 255, 255)),
+            PIL.Image.new("L", (512, 512), color=255),
+        ),
+        PIL.Image.Image,
+    )
+    with pytest.raises(ValueError):
+        te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)))
+    with pytest.raises(ValueError):
+        te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)))
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/tracer_b7.html b/docs/api/tracer_b7.html new file mode 100644 index 0000000..f104f77 --- /dev/null +++ b/docs/api/tracer_b7.html @@ -0,0 +1,487 @@ + + + + + + +tracer_b7 API documentation + + + + + + + + + + + +
+
+
+

Module tracer_b7

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+from typing import List, Union
+
+import PIL.Image
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
+from carvekit.ml.arch.tracerb7.tracer import TracerDecoder
+from carvekit.ml.files.models_loc import tracer_b7_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.models_utils import get_precision_autocast, cast_network
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["TracerUniversalB7"]
+
+
+class TracerUniversalB7(TracerDecoder):
+    """TRACER B7 model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 640,
+        batch_size: int = 4,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the TRACER model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=640): input image size
+            batch_size (int, default=4): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision
+            model_path (Union[str, pathlib.Path], default=None): path to the model
+            .. note:: REDO
+        """
+        if model_path is None:
+            model_path = tracer_b7_pretrained()
+        super(TracerUniversalB7, self).__init__(
+            encoder=EfficientEncoderB7(),
+            rfb_channel=[32, 64, 128],
+            features_channels=[48, 80, 224, 640],
+        )
+
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Resize(self.input_image_size),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        self.to(device)
+        if load_pretrained:
+            # TODO remove edge detector from weights. It doesn't work well with this model!
+            self.load_state_dict(
+                torch.load(model_path, map_location=self.device), strict=False
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask
+
+        """
+        output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
+            np.uint8
+        )
+        output = output.squeeze(0)
+        mask = Image.fromarray(output).convert("L")
+        mask = mask.resize(original_image.size, resample=Image.BILINEAR)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = batches.to(self.device)
+                    masks = super(TracerDecoder, self).__call__(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(
+                        masks_cpu[x], converted_images[x]
+                    ),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class TracerUniversalB7 +(device='cpu', input_image_size: Union[List[int], int] = 640, batch_size: int = 4, load_pretrained: bool = True, fp16: bool = False, model_path: Union[str, pathlib.Path] = None) +
+
+

TRACER B7 model interface

+

Initialize the TRACER model

+

Args

+
+
device : Literal[cpu, cuda], default=cpu
+
processing device
+
input_image_size : Union[List[int], int], default=640
+
input image size
+
batch_size : int, default=4
+
the number of images that the neural network processes in one run
+
load_pretrained : bool, default=True
+
loading pretrained model
+
fp16 : bool, default=False
+
use fp16 precision
+
model_path : Union[str, pathlib.Path], default=None
+
path to the model
+
+
+

Note: REDO

+
+
+ +Expand source code + +
class TracerUniversalB7(TracerDecoder):
+    """TRACER B7 model interface"""
+
+    def __init__(
+        self,
+        device="cpu",
+        input_image_size: Union[List[int], int] = 640,
+        batch_size: int = 4,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+        model_path: Union[str, pathlib.Path] = None,
+    ):
+        """
+        Initialize the TRACER model
+
+        Args:
+            device (Literal[cpu, cuda], default=cpu): processing device
+            input_image_size (Union[List[int], int], default=640): input image size
+            batch_size (int, default=4): the number of images that the neural network processes in one run
+            load_pretrained (bool, default=True): loading pretrained model
+            fp16 (bool, default=False): use fp16 precision
+            model_path (Union[str, pathlib.Path], default=None): path to the model
+            .. note:: REDO
+        """
+        if model_path is None:
+            model_path = tracer_b7_pretrained()
+        super(TracerUniversalB7, self).__init__(
+            encoder=EfficientEncoderB7(),
+            rfb_channel=[32, 64, 128],
+            features_channels=[48, 80, 224, 640],
+        )
+
+        self.fp16 = fp16
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Resize(self.input_image_size),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        self.to(device)
+        if load_pretrained:
+            # TODO remove edge detector from weights. It doesn't work well with this model!
+            self.load_state_dict(
+                torch.load(model_path, map_location=self.device), strict=False
+            )
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+
+        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask
+
+        """
+        output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
+            np.uint8
+        )
+        output = output.squeeze(0)
+        mask = Image.fromarray(output).convert("L")
+        mask = mask.resize(original_image.size, resample=Image.BILINEAR)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images (List[Union[str, pathlib.Path, PIL.Image.Image]]): input images
+
+        Returns:
+            List[PIL.Image.Image]: segmentation masks for the input images
+
+        """
+        collect_masks = []
+        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
+        with autocast:
+            cast_network(self, dtype)
+            for image_batch in batch_generator(images, self.batch_size):
+                converted_images = thread_pool_processing(
+                    lambda x: convert_image(load_image(x)), image_batch
+                )
+                batches = torch.vstack(
+                    thread_pool_processing(self.data_preprocessing, converted_images)
+                )
+                with torch.no_grad():
+                    batches = batches.to(self.device)
+                    masks = super(TracerDecoder, self).__call__(batches)
+                    masks_cpu = masks.cpu()
+                    del batches, masks
+                masks = thread_pool_processing(
+                    lambda x: self.data_postprocessing(
+                        masks_cpu[x], converted_images[x]
+                    ),
+                    range(len(converted_images)),
+                )
+                collect_masks += masks
+
+        return collect_masks
+
+

Ancestors

+ +

Static methods

+
+
+def data_postprocessing(data: torch.Tensor, original_image: PIL.Image.Image) ‑> PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask
+
+    """
+    output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
+        np.uint8
+    )
+    output = output.squeeze(0)
+    mask = Image.fromarray(output).convert("L")
+    mask = mask.resize(original_image.size, resample=Image.BILINEAR)
+    return mask
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data: PIL.Image.Image) ‑> torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.FloatTensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.FloatTensor: input for neural network
+
+    """
+
+    return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/api/u2net.html b/docs/api/u2net.html new file mode 100644 index 0000000..724e75b --- /dev/null +++ b/docs/api/u2net.html @@ -0,0 +1,480 @@ + + + + + + +u2net API documentation + + + + + + + + + + + +
+
+
+

Module u2net

+
+
+

Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0

+
+ +Expand source code + +
"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import pathlib
+import warnings
+
+from typing import List, Union
+import PIL.Image
+import numpy as np
+import torch
+from PIL import Image
+
+from carvekit.ml.arch.u2net.u2net import U2NETArchitecture
+from carvekit.ml.files.models_loc import u2net_full_pretrained
+from carvekit.utils.image_utils import load_image, convert_image
+from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
+
+__all__ = ["U2NET"]
+
+
+class U2NET(U2NETArchitecture):
+    """U^2-Net model interface"""
+
+    def __init__(
+        self,
+        layers_cfg="full",
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the U2NET model
+
+        Args:
+            layers_cfg: neural network layers configuration
+            device: processing device
+            input_image_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use fp16 precision // not supported at this moment.
+
+        """
+        super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1)
+        if fp16:
+            warnings.warn("FP16 is not supported at this moment for U2NET model")
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(u2net_full_pretrained(), map_location=self.device)
+            )
+
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size, resample=3)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=float)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+
+        Returns:
+            segmentation masks for the input images, as PIL.Image.Image instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches)
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class U2NET +(layers_cfg='full', device='cpu', input_image_size: Union[List[int], int] = 320, batch_size: int = 10, load_pretrained: bool = True, fp16: bool = False) +
+
+

U^2-Net model interface

+

Initialize the U2NET model

+

Args

+
+
layers_cfg
+
neural network layers configuration
+
device
+
processing device
+
input_image_size
+
input image size
+
batch_size
+
the number of images that the neural network processes in one run
+
load_pretrained
+
loading pretrained model
+
fp16
+
use fp16 precision // not supported at this moment.
+
+
+ +Expand source code + +
class U2NET(U2NETArchitecture):
+    """U^2-Net model interface"""
+
+    def __init__(
+        self,
+        layers_cfg="full",
+        device="cpu",
+        input_image_size: Union[List[int], int] = 320,
+        batch_size: int = 10,
+        load_pretrained: bool = True,
+        fp16: bool = False,
+    ):
+        """
+        Initialize the U2NET model
+
+        Args:
+            layers_cfg: neural network layers configuration
+            device: processing device
+            input_image_size: input image size
+            batch_size: the number of images that the neural network processes in one run
+            load_pretrained: loading pretrained model
+            fp16: use fp16 precision // not supported at this moment.
+
+        """
+        super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1)
+        if fp16:
+            warnings.warn("FP16 is not supported at this moment for U2NET model")
+        self.device = device
+        self.batch_size = batch_size
+        if isinstance(input_image_size, list):
+            self.input_image_size = input_image_size[:2]
+        else:
+            self.input_image_size = (input_image_size, input_image_size)
+        self.to(device)
+        if load_pretrained:
+            self.load_state_dict(
+                torch.load(u2net_full_pretrained(), map_location=self.device)
+            )
+
+        self.eval()
+
+    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+        """
+        Transform input image to suitable data format for neural network
+
+        Args:
+            data (PIL.Image.Image): input image
+
+        Returns:
+            torch.FloatTensor: input for neural network
+
+        """
+        resized = data.resize(self.input_image_size, resample=3)
+        # noinspection PyTypeChecker
+        resized_arr = np.array(resized, dtype=float)
+        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+        if np.max(resized_arr) != 0:
+            resized_arr /= np.max(resized_arr)
+        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+        temp_image = temp_image.transpose((2, 0, 1))
+        temp_image = np.expand_dims(temp_image, 0)
+        return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+    @staticmethod
+    def data_postprocessing(
+        data: torch.Tensor, original_image: PIL.Image.Image
+    ) -> PIL.Image.Image:
+        """
+        Transforms output data from neural network to suitable data
+        format for using with other components of this framework.
+
+        Args:
+            data (torch.Tensor): output data from neural network
+            original_image (PIL.Image.Image): input image which was used for predicted data
+
+        Returns:
+            PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+        """
+        data = data.unsqueeze(0)
+        mask = data[:, 0, :, :]
+        ma = torch.max(mask)  # Normalizes prediction
+        mi = torch.min(mask)
+        predict = ((mask - mi) / (ma - mi)).squeeze()
+        predict_np = predict.cpu().data.numpy() * 255
+        mask = Image.fromarray(predict_np).convert("L")
+        mask = mask.resize(original_image.size, resample=3)
+        return mask
+
+    def __call__(
+        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
+    ) -> List[PIL.Image.Image]:
+        """
+        Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
+
+        Args:
+            images: input images
+
+        Returns:
+            segmentation masks for the input images, as PIL.Image.Image instances
+
+        """
+        collect_masks = []
+        for image_batch in batch_generator(images, self.batch_size):
+            converted_images = thread_pool_processing(
+                lambda x: convert_image(load_image(x)), image_batch
+            )
+            batches = torch.vstack(
+                thread_pool_processing(self.data_preprocessing, converted_images)
+            )
+            with torch.no_grad():
+                batches = batches.to(self.device)
+                masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches)
+                masks_cpu = masks.cpu()
+                del d2, d3, d4, d5, d6, d7, batches, masks
+            masks = thread_pool_processing(
+                lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
+                range(len(converted_images)),
+            )
+            collect_masks += masks
+        return collect_masks
+
+

Ancestors

+ +

Static methods

+
+
+def data_postprocessing(data: torch.Tensor, original_image: PIL.Image.Image) ‑> PIL.Image.Image +
+
+

Transforms output data from neural network to suitable data +format for using with other components of this framework.

+

Args

+
+
data : torch.Tensor
+
output data from neural network
+
original_image : PIL.Image.Image
+
input image which was used for predicted data
+
+

Returns

+
+
PIL.Image.Image
+
Segmentation mask as PIL Image instance
+
+
+ +Expand source code + +
@staticmethod
+def data_postprocessing(
+    data: torch.Tensor, original_image: PIL.Image.Image
+) -> PIL.Image.Image:
+    """
+    Transforms output data from neural network to suitable data
+    format for using with other components of this framework.
+
+    Args:
+        data (torch.Tensor): output data from neural network
+        original_image (PIL.Image.Image): input image which was used for predicted data
+
+    Returns:
+        PIL.Image.Image: Segmentation mask as `PIL Image` instance
+
+    """
+    data = data.unsqueeze(0)
+    mask = data[:, 0, :, :]
+    ma = torch.max(mask)  # Normalizes prediction
+    mi = torch.min(mask)
+    predict = ((mask - mi) / (ma - mi)).squeeze()
+    predict_np = predict.cpu().data.numpy() * 255
+    mask = Image.fromarray(predict_np).convert("L")
+    mask = mask.resize(original_image.size, resample=3)
+    return mask
+
+
+
+

Methods

+
+
+def data_preprocessing(self, data: PIL.Image.Image) ‑> torch.FloatTensor +
+
+

Transform input image to suitable data format for neural network

+

Args

+
+
data : PIL.Image.Image
+
input image
+
+

Returns

+
+
torch.FloatTensor
+
input for neural network
+
+
+ +Expand source code + +
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
+    """
+    Transform input image to suitable data format for neural network
+
+    Args:
+        data (PIL.Image.Image): input image
+
+    Returns:
+        torch.FloatTensor: input for neural network
+
+    """
+    resized = data.resize(self.input_image_size, resample=3)
+    # noinspection PyTypeChecker
+    resized_arr = np.array(resized, dtype=float)
+    temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
+    if np.max(resized_arr) != 0:
+        resized_arr /= np.max(resized_arr)
+    temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
+    temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
+    temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
+    temp_image = temp_image.transpose((2, 0, 1))
+    temp_image = np.expand_dims(temp_image, 0)
+    return torch.from_numpy(temp_image).type(torch.FloatTensor)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/imgs/input/1_bg_removed.png b/docs/imgs/input/1_bg_removed.png index a1e44f6..c0443fd 100644 Binary files a/docs/imgs/input/1_bg_removed.png and b/docs/imgs/input/1_bg_removed.png differ diff --git a/docs/imgs/input/2_bg_removed.png b/docs/imgs/input/2_bg_removed.png index a30c041..86e7097 100644 Binary files a/docs/imgs/input/2_bg_removed.png and b/docs/imgs/input/2_bg_removed.png differ diff --git a/docs/imgs/input/3_bg_removed.png b/docs/imgs/input/3_bg_removed.png index 298e17f..5157287 100644 Binary files a/docs/imgs/input/3_bg_removed.png and b/docs/imgs/input/3_bg_removed.png differ diff --git a/docs/imgs/input/4_bg_removed.png b/docs/imgs/input/4_bg_removed.png index 32b6a1c..89097c1 100644 Binary files a/docs/imgs/input/4_bg_removed.png and b/docs/imgs/input/4_bg_removed.png differ diff --git a/docs/other/carvekit_try.ipynb b/docs/other/carvekit_try.ipynb index 484ee9c..35989bd 100644 --- a/docs/other/carvekit_try.ipynb +++ b/docs/other/carvekit_try.ipynb @@ -1,204 +1,180 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "carvekit-try.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU", - "gpuClass": "standard" + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "carvekit-try.ipynb", + "provenance": [], + "collapsed_sections": [] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "![logo.png]()" - ], - "metadata": { - "id": "-BV5wSJzQ-ev", - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "### Automated high-quality background removal framework for an image using neural networks\n", - "\n", - "\n", - "\n", - "- 🏒 [Project at GitHub](https://github.com/OPHoperHPO/image-background-remove-tool) 🏒\n", - "- πŸ”— [Author at GitHub](https://github.com/OPHoperHPO) πŸ”—\n", - "\n", - "> Please 
rate our repository with ⭐ if you like our work! Thanks! πŸ˜€" - ], - "metadata": { - "id": "Yq1sa5BbRV4c", - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "This notebook supports **Google Colab GPU runtime**. \n", - "\n", - "> **Enabling and testing the GPU** \\\n", - "> Navigate to `Edit β†’ Notebook Settings`. \\\n", - "> Select `GPU` from the `Hardware Accelerator` drop-down." - ], - "metadata": { - "id": "lrGOILABYqXx", - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sqwsUfoI3SnG", - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "# Install CarveKit" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "7C4rC_HQi1gq", - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "#@title Install colab-ready python package (Click the arrow on the left)\n", - "%cd /content\n", - "!pip install carvekit_colab\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "#@title Download all models\n", - "from carvekit.ml.files.models_loc import download_all\n", - "\n", - "download_all();" - ], - "metadata": { - "cellView": "form", - "id": "EPjtRXRpQ2k7", - "pycharm": { - "name": "#%%\n" - } - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pF-4SVcB3gjK", - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "# Remove background using CarveKit" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "![logo.png]()" + ], + "metadata": { + "id": "-BV5wSJzQ-ev" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Automated high-quality background removal framework for an image using neural networks\n", + "\n", + "\n", + "\n", + "- 🏒 [Project at 
GitHub](https://github.com/OPHoperHPO/image-background-remove-tool) 🏒\n", + "- πŸ”— [Author at GitHub](https://github.com/OPHoperHPO) πŸ”—\n", + "\n", + "> Please rate our repository with ⭐ if you like our work! Thanks! πŸ˜€" + ], + "metadata": { + "id": "Yq1sa5BbRV4c" + } + }, + { + "cell_type": "markdown", + "source": [ + "This notebook supports **Google Colab GPU runtime**. \n", + "\n", + "> **Enabling and testing the GPU** \\\n", + "> Navigate to `Edit β†’ Notebook Settings`. \\\n", + "> Select `GPU` from the `Hardware Accelerator` drop-down." + ], + "metadata": { + "id": "lrGOILABYqXx" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sqwsUfoI3SnG" + }, + "source": [ + "# Install CarveKit" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "7C4rC_HQi1gq" + }, + "source": [ + "#@title Install colab-ready python package (Click the arrow on the left)\n", + "%cd /content\n", + "!pip install carvekit_colab\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Download all models\n", + "from carvekit.ml.files.models_loc import download_all\n", + "\n", + "download_all();" + ], + "metadata": { + "cellView": "form", + "id": "EPjtRXRpQ2k7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pF-4SVcB3gjK" + }, + "source": [ + "# Remove background using CarveKit" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "rgm6pR6U22a9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 529 }, + "cellView": "form", + "outputId": "a908d208-0520-42ec-dbe0-c06e6c4ee260" + }, + "source": [ + "#@title Upload images from your computer\n", + "#@markdown Description of parameters\n", + "#@markdown - `SHOW_FULLSIZE` - Shows image in full size (may take a long time to load)\n", + "#@markdown - `PREPROCESSING_METHOD` - Preprocessing method. `AutoScene` will automatically select needed model depends on your image. 
If you don't want, disable it.\n", + "#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `u2net` for hairs-like objects and `tracer_b7` for objects\n", + "#@markdown - `POSTPROCESSING_METHOD` - Postprocessing method\n", + "#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. Use 640 for Tracer B7 and 320 for U2Net\n", + "#@markdown - `TRIMAP_DILATION` - The size of the offset radius from the object mask in pixels when forming an unknown area\n", + "#@markdown - `TRIMAP_EROSION` - The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area\n", + "#@markdown > Look README.md and code for more details on networks and methods\n", + "\n", + "\n", + "import torch\n", + "from IPython import display\n", + "from google.colab import files\n", + "from carvekit.web.schemas.config import MLConfig\n", + "from carvekit.web.utils.init_utils import init_interface\n", + "\n", + "SHOW_FULLSIZE = False #@param {type:\"boolean\"}\n", + "PREPROCESSING_METHOD = \"autoscene\" #@param [\"autoscene\", \"auto\", \"none\"]\n", + "SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\"]\n", + "POSTPROCESSING_METHOD = \"cascade_fba\" #@param [\"fba\", \"cascade_fba\", \"none\"]\n", + "SEGMENTATION_MASK_SIZE = 640 #@param [\"640\", \"320\"] {type:\"raw\", allow-input: true}\n", + "TRIMAP_DILATION = 30 #@param {type:\"integer\"}\n", + "TRIMAP_EROSION = 5 #@param {type:\"integer\"}\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "\n", + "config = MLConfig(segmentation_network=SEGMENTATION_NETWORK,\n", + " preprocessing_method=PREPROCESSING_METHOD,\n", + " postprocessing_method=POSTPROCESSING_METHOD,\n", + " seg_mask_size=SEGMENTATION_MASK_SIZE,\n", + " trimap_dilation=TRIMAP_DILATION,\n", + " trimap_erosion=TRIMAP_EROSION,\n", + " device=DEVICE)\n", + "\n", + "\n", + "interface = init_interface(config)\n", + "\n", + "\n", + "\n", + "\n", + "uploaded = 
files.upload().keys()\n", + "display.clear_output()\n", + "images = interface(uploaded)\n", + "for im in enumerate(images):\n", + " if not SHOW_FULLSIZE:\n", + " im[1].thumbnail((768, 768), resample=3)\n", + " display.display(im[1])\n", + "\n" + ], + "execution_count": 5, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "rgm6pR6U22a9", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 529 - }, - "cellView": "form", - "outputId": "a908d208-0520-42ec-dbe0-c06e6c4ee260", - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "#@title Upload images from your computer\n", - "#@markdown Description of parameters\n", - "#@markdown - `SHOW_FULLSIZE` - Shows image in full size (may take a long time to load)\n", - "#@markdown - `PREPROCESSING_METHOD` - Preprocessing method\n", - "#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `u2net` for hairs-like objects and `tracer_b7` for objects\n", - "#@markdown - `POSTPROCESSING_METHOD` - Postprocessing method\n", - "#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. 
Use 640 for Tracer B7 and 320 for U2Net\n", - "#@markdown - `TRIMAP_DILATION` - The size of the offset radius from the object mask in pixels when forming an unknown area\n", - "#@markdown - `TRIMAP_EROSION` - The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area\n", - "#@markdown > Look README.md and code for more details on networks and methods\n", - "\n", - "\n", - "import torch\n", - "from IPython import display\n", - "from google.colab import files\n", - "from carvekit.web.schemas.config import MLConfig\n", - "from carvekit.web.utils.init_utils import init_interface\n", - "\n", - "SHOW_FULLSIZE = False #@param {type:\"boolean\"}\n", - "PREPROCESSING_METHOD = \"none\" #@param [\"stub\", \"none\"]\n", - "SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\"]\n", - "POSTPROCESSING_METHOD = \"fba\" #@param [\"fba\", \"none\"] \n", - "SEGMENTATION_MASK_SIZE = 640 #@param [\"640\", \"320\"] {type:\"raw\", allow-input: true}\n", - "TRIMAP_DILATION = 30 #@param {type:\"integer\"}\n", - "TRIMAP_EROSION = 5 #@param {type:\"integer\"}\n", - "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - "\n", - "\n", - "config = MLConfig(segmentation_network=SEGMENTATION_NETWORK,\n", - " preprocessing_method=PREPROCESSING_METHOD,\n", - " postprocessing_method=POSTPROCESSING_METHOD,\n", - " seg_mask_size=SEGMENTATION_MASK_SIZE,\n", - " trimap_dilation=TRIMAP_DILATION,\n", - " trimap_erosion=TRIMAP_EROSION,\n", - " device=DEVICE)\n", - "\n", - "\n", - "interface = init_interface(config)\n", - "\n", - "\n", - "\n", - "\n", - "uploaded = files.upload().keys()\n", - "display.clear_output()\n", - "images = interface(uploaded)\n", - "for im in enumerate(images):\n", - " if not SHOW_FULLSIZE:\n", - " im[1].thumbnail((768, 768), resample=3)\n", - " display.display(im[1])\n", - "\n" + "output_type": "display_data", + "data": { + "text/plain": [ + "" ], - "execution_count": 5, - 
"outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "image/png": "\n" - }, - "metadata": {} - } - ] + "image/png": "\n" + }, + "metadata": {} } - ] -} \ No newline at end of file + ] + } + ] +} diff --git a/docs/readme/ru.md b/docs/readme/ru.md index c001adc..4dea7ea 100644 --- a/docs/readme/ru.md +++ b/docs/readme/ru.md @@ -25,13 +25,16 @@ ## πŸŽ† ΠžΡΠΎΠ±Π΅Π½Π½ΠΎΡΡ‚ΠΈ: - ВысокоС качСство Π²Ρ‹Ρ…ΠΎΠ΄Π½ΠΎΠ³ΠΎ изобраТСния +- Π Π°Π±ΠΎΡ‚Π°Π΅Ρ‚ Π² Π°Π²Ρ‚ΠΎΠ½ΠΎΠΌΠ½ΠΎΠΌ Ρ€Π΅ΠΆΠΈΠΌΠ΅ - ΠŸΠ°ΠΊΠ΅Ρ‚Π½Π°Ρ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ - ΠŸΠΎΠ΄Π΄Π΅Ρ€ΠΆΠΊΠ° NVIDIA CUDA ΠΈ процСссорной ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ - ΠŸΠΎΠ΄Π΄Π΅Ρ€ΠΆΠΊΠ° FP16: быстрая ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° с Π½ΠΈΠ·ΠΊΠΈΠΌ ΠΏΠΎΡ‚Ρ€Π΅Π±Π»Π΅Π½ΠΈΠ΅ΠΌ памяти - Π›Π΅Π³ΠΊΠΎΠ΅ взаимодСйствиС ΠΈ запуск - 100% совмСстимоС с remove.bg API FastAPI HTTP API - УдаляСт Ρ„ΠΎΠ½ с волос +- АвтоматичСский Π²Ρ‹Π±ΠΎΡ€ Π»ΡƒΡ‡ΡˆΠ΅Π³ΠΎ ΠΌΠ΅Ρ‚ΠΎΠ΄Π° для изобраТСния ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Ρ - ΠŸΡ€ΠΎΡΡ‚Π°Ρ интСграция с вашим ΠΊΠΎΠ΄ΠΎΠΌ +- МодСли Ρ€Π°Π·ΠΌΠ΅Ρ‰Π΅Π½Ρ‹ Π½Π° [HuggingFace](https://huggingface.co/Carve) ## β›± ΠŸΠΎΠΏΡ€ΠΎΠ±ΡƒΠΉΡ‚Π΅ сами Π½Π° [Google Colab](https://colab.research.google.com/github/OPHoperHPO/image-background-remove-tool/blob/master/docs/other/carvekit_try.ipynb) ## ⛓️ Как это Ρ€Π°Π±ΠΎΡ‚Π°Π΅Ρ‚? @@ -40,6 +43,7 @@ 2. ΠŸΡ€ΠΎΠΈΡΡ…ΠΎΠ΄ΠΈΡ‚ ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° Ρ„ΠΎΡ‚ΠΎΠ³Ρ€Π°Ρ„ΠΈΠΈ для обСспСчСния Π»ΡƒΡ‡ΡˆΠ΅Π³ΠΎ качСства Π²Ρ‹Ρ…ΠΎΠ΄Π½ΠΎΠ³ΠΎ изобраТСния 3. Π‘ ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ Ρ‚Π΅Ρ…Π½ΠΎΠ»ΠΎΠ³ΠΈΠΈ машинного обучСния убираСтся Ρ„ΠΎΠ½ Ρƒ изобраТСния 4. 
ΠŸΡ€ΠΎΠΈΡΡ…ΠΎΠ΄ΠΈΡ‚ постобработка изобраТСния для ΡƒΠ»ΡƒΡ‡ΡˆΠ΅Π½ΠΈΡ качСства ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚Π°Π½Π½ΠΎΠ³ΠΎ изобраТСния + ## πŸŽ“ Implemented Neural Networks: | НСйронныС сСти | ЦСлСвая ΠΎΠ±Π»Π°ΡΡ‚ΡŒ | Π’ΠΎΡ‡Π½ΠΎΡΡ‚ΡŒ | |:--------------:|:--------------------------------------------:|:--------------------------------:| @@ -47,14 +51,35 @@ | U^2-net | **Волосы** (hairs, people, animals, objects) | 80% (mean F1-Score, DUTS-TE) | | BASNet | **ΠžΠ±Ρ‰ΠΈΠΉ** (people, objects) | 80% (mean F1-Score, DUTS-TE) | | DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017) | -> Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠΉΡ‚Π΅ U2-Net для волос ΠΈ Tracer-B7 для ΠΎΠ±Ρ‹Ρ‡Π½Ρ‹Ρ… ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ. -## πŸ–ΌοΈ ΠœΠ΅Ρ‚ΠΎΠ΄Ρ‹ ΠΏΡ€Π΅Π΄Π²Π°Ρ€ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ ΠΈ постобработки ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ: -### πŸ” ΠœΠ΅Ρ‚ΠΎΠ΄Ρ‹ ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ: -* `none` - ΠΌΠ΅Ρ‚ΠΎΠ΄Ρ‹ ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ Π½Π΅ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΡŽΡ‚ΡΡ. -> Они Π±ΡƒΠ΄ΡƒΡ‚ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½Ρ‹ Π² Π±ΡƒΠ΄ΡƒΡ‰Π΅ΠΌ. + +### Recommended parameters for different models +| НСйронныС сСти | Π Π°Π·ΠΌΠ΅Ρ€ маски сСгмСнтации | ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Trimap (Ρ€Π°ΡΡˆΠΈΡ€Π΅Π½ΠΈΠ΅, эрозия) | +|:--------------:|:------------------------:|:-------------------------------------:| +| `tracer_b7` | 640 | (30, 5) | +| `u2net` | 320 | (30, 5) | +| `basnet` | 320 | (30, 5) | +| `deeplabv3` | 1024 | (40, 20) | + +> ### Notes: +> 1. ΠžΠΊΠΎΠ½Ρ‡Π°Ρ‚Π΅Π»ΡŒΠ½ΠΎΠ΅ качСство ΠΌΠΎΠΆΠ΅Ρ‚ Π·Π°Π²ΠΈΡΠ΅Ρ‚ΡŒ ΠΎΡ‚ Ρ€Π°Π·Ρ€Π΅ΡˆΠ΅Π½ΠΈΡ вашСго изобраТСния, Ρ‚ΠΈΠΏΠ° сцСны ΠΈΠ»ΠΈ ΠΎΠ±ΡŠΠ΅ΠΊΡ‚Π°. +> 2. Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠΉΡ‚Π΅ U2-Net для волос ΠΈ Tracer-B7 для ΠΎΠ±Ρ‰ΠΈΡ… ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ ΠΈ ΠΏΡ€Π°Π²ΠΈΠ»ΡŒΠ½Ρ‹Ρ… ΠΏΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€ΠΎΠ². \ +> Π­Ρ‚ΠΎ ΠΎΡ‡Π΅Π½ΡŒ Π²Π°ΠΆΠ½ΠΎ для ΠΊΠΎΠ½Π΅Ρ‡Π½ΠΎΠ³ΠΎ качСства! ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ Π±Ρ‹Π»ΠΈ ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½Ρ‹ с использованиСм постобработки U2-Net ΠΈ FBA. + +## πŸ–ΌοΈ Image pre-processing and post-processing methods: +### πŸ” Preprocessing methods: +* `none` - No preprocessing methods used. 
+* [`autoscene`](https://huggingface.co/Carve/scene_classifier/) - АвтоматичСски опрСдСляСт Ρ‚ΠΈΠΏ сцСны с ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ классификатора ΠΈ примСняСт ΡΠΎΠΎΡ‚Π²Π΅Ρ‚ΡΡ‚Π²ΡƒΡŽΡ‰ΡƒΡŽ модСль. (По ΡƒΠΌΠΎΠ»Ρ‡Π°Π½ΠΈΡŽ) +* `auto` - ВыполняСт Π³Π»ΡƒΠ±ΠΎΠΊΠΈΠΉ Π°Π½Π°Π»ΠΈΠ· изобраТСния ΠΈ Π±ΠΎΠ»Π΅Π΅ Ρ‚ΠΎΡ‡Π½ΠΎ опрСдСляСт Π»ΡƒΡ‡ΡˆΠΈΠΉ ΠΌΠ΅Ρ‚ΠΎΠ΄ удалСния Ρ„ΠΎΠ½Π°. Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ классификатор ΠΎΠ±ΡŠΠ΅ΠΊΡ‚ΠΎΠ² ΠΈ классификатор сцСны вмСстС. +> ### Notes: +> 1. `AutoScene` ΠΈ `auto` ΠΌΠΎΠ³ΡƒΡ‚ ΠΏΠ΅Ρ€Π΅ΠΎΠΏΡ€Π΅Π΄Π΅Π»ΠΈΡ‚ΡŒ модСль ΠΈ ΠΏΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹, ΡƒΠΊΠ°Π·Π°Π½Π½Ρ‹Π΅ ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Π΅ΠΌ, Π±Π΅Π· увСдомлСния. +> Π˜Ρ‚Π°ΠΊ, Ссли Π²Ρ‹ Ρ…ΠΎΡ‚ΠΈΡ‚Π΅ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚ΡŒ ΠΊΠΎΠ½ΠΊΡ€Π΅Ρ‚Π½ΡƒΡŽ модСль, ΡΠ΄Π΅Π»Π°Ρ‚ΡŒ всС постоянными ΠΈ Ρ‚. Π΄., Π²Π°ΠΌ слСдуСт сначала ΠΎΡ‚ΠΊΠ»ΡŽΡ‡ΠΈΡ‚ΡŒ ΠΌΠ΅Ρ‚ΠΎΠ΄Ρ‹ автоматичСской ΠΏΡ€Π΅Π΄Π²Π°Ρ€ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ! +> 2. На Π΄Π°Π½Π½Ρ‹ΠΉ ΠΌΠΎΠΌΠ΅Π½Ρ‚ для ΠΌΠ΅Ρ‚ΠΎΠ΄Π° `auto` Π²Ρ‹Π±ΠΈΡ€Π°ΡŽΡ‚ΡΡ ΡƒΠ½ΠΈΠ²Π΅Ρ€ΡΠ°Π»ΡŒΠ½Ρ‹Π΅ ΠΌΠΎΠ΄Π΅Π»ΠΈ для Π½Π΅ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Ρ… ΠΊΠΎΠ½ΠΊΡ€Π΅Ρ‚Π½Ρ‹Ρ… Π΄ΠΎΠΌΠ΅Π½ΠΎΠ², Ρ‚Π°ΠΊ ΠΊΠ°ΠΊ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½Π½Ρ‹Ρ… ΠΌΠΎΠ΄Π΅Π»Π΅ΠΉ Π² настоящСС врСмя нСдостаточно для Ρ‚Π°ΠΊΠΎΠ³ΠΎ количСства Ρ‚ΠΈΠΏΠΎΠ² сцСн. +> Π’ Π±ΡƒΠ΄ΡƒΡ‰Π΅ΠΌ, ΠΊΠΎΠ³Π΄Π° Π±ΡƒΠ΄Π΅Ρ‚ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΎ Π½Π΅ΠΊΠΎΡ‚ΠΎΡ€ΠΎΠ΅ Ρ€Π°Π·Π½ΠΎΠΎΠ±Ρ€Π°Π·ΠΈΠ΅ ΠΌΠΎΠ΄Π΅Π»Π΅ΠΉ, Π°Π²Ρ‚ΠΎΠΏΠΎΠ΄Π±ΠΎΡ€ Π±ΡƒΠ΄Π΅Ρ‚ пСрСписан Π² Π»ΡƒΡ‡ΡˆΡƒΡŽ сторону. + ### βœ‚ ΠœΠ΅Ρ‚ΠΎΠ΄Ρ‹ постобработки: * `none` - ΠΌΠ΅Ρ‚ΠΎΠ΄Ρ‹ постобработки Π½Π΅ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΡŽΡ‚ΡΡ -* `fba` (ΠΏΠΎ ΡƒΠΌΠΎΠ»Ρ‡Π°Π½ΠΈΡŽ) - Π­Ρ‚ΠΎΡ‚ Π°Π»Π³ΠΎΡ€ΠΈΡ‚ΠΌ ΡƒΠ»ΡƒΡ‡ΡˆΠ°Π΅Ρ‚ Π³Ρ€Π°Π½ΠΈΡ†Ρ‹ изобраТСния ΠΏΡ€ΠΈ ΡƒΠ΄Π°Π»Π΅Π½ΠΈΠΈ Ρ„ΠΎΠ½Π° с ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ с волосами ΠΈ Ρ‚.Π΄. с ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ Π½Π΅ΠΉΡ€ΠΎΠ½Π½ΠΎΠΉ сСти FBA Matting. Π­Ρ‚ΠΎΡ‚ ΠΌΠ΅Ρ‚ΠΎΠ΄ Π΄Π°Π΅Ρ‚ Π½Π°ΠΈΠ»ΡƒΡ‡ΡˆΠΈΠΉ Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ Π² сочСтании с u2net Π±Π΅Π· ΠΊΠ°ΠΊΠΈΡ…-Π»ΠΈΠ±ΠΎ ΠΌΠ΅Ρ‚ΠΎΠ΄ΠΎΠ² ΠΏΡ€Π΅Π΄Π²Π°Ρ€ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ. 
+* `fba` - Π­Ρ‚ΠΎΡ‚ Π°Π»Π³ΠΎΡ€ΠΈΡ‚ΠΌ ΡƒΠ»ΡƒΡ‡ΡˆΠ°Π΅Ρ‚ Π³Ρ€Π°Π½ΠΈΡ†Ρ‹ изобраТСния ΠΏΡ€ΠΈ ΡƒΠ΄Π°Π»Π΅Π½ΠΈΠΈ Ρ„ΠΎΠ½Π° с ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ с волосами ΠΈ Ρ‚.Π΄. с ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ Π½Π΅ΠΉΡ€ΠΎΠ½Π½ΠΎΠΉ сСти FBA Matting. +* `cascade_fba` (default) - Π­Ρ‚ΠΎΡ‚ Π°Π»Π³ΠΎΡ€ΠΈΡ‚ΠΌ уточняСт маску сСгмСнтации с ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ Π½Π΅ΠΉΡ€ΠΎΠ½Π½ΠΎΠΉ сСти CascadePSP, Π° Π·Π°Ρ‚Π΅ΠΌ примСняСт Π°Π»Π³ΠΎΡ€ΠΈΡ‚ΠΌ FBA. ## 🏷 Настройка для ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ Π½Π° CPU: 1. `pip install carvekit --extra-index-url https://download.pytorch.org/whl/cpu` @@ -62,7 +87,7 @@ ## 🏷 Настройка для ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ Π½Π° GPU: 1. Π£Π±Π΅Π΄ΠΈΡ‚Π΅ΡΡŒ, Ρ‡Ρ‚ΠΎ Ρƒ вас Π΅ΡΡ‚ΡŒ графичСский процСссор NVIDIA с 8 Π“Π‘ видСопамяти. -2. УстановитС `CUDA Toolkit ΠΈ Π’ΠΈΠ΄Π΅ΠΎ Π΄Ρ€Π°Π²Π΅Ρ€ для вашСй Π²ΠΈΠ΄Π΅ΠΎΠΊΠ°Ρ€Ρ‚Ρ‹.` +2. УстановитС `CUDA Toolkit ΠΈ Π’ΠΈΠ΄Π΅ΠΎΠ΄Ρ€Π°ΠΉΠ²Π΅Ρ€ для вашСй Π²ΠΈΠ΄Π΅ΠΎΠΊΠ°Ρ€Ρ‚Ρ‹.` 3. `pip install carvekit --extra-index-url https://download.pytorch.org/whl/cu113` > ΠŸΡ€ΠΎΠ΅ΠΊΡ‚ ΠΏΠΎΠ΄Π΄Π΅Ρ€ΠΆΠΈΠ²Π°Π΅Ρ‚ вСрсии Python ΠΎΡ‚ 3.8 Π΄ΠΎ 3.10.4. @@ -73,12 +98,15 @@ import torch from carvekit.api.high import HiInterface # Check doc strings for more information -interface = HiInterface(object_type="hairs-like", # Can be "object" or "hairs-like". 
+interface = HiInterface(object_type="auto", # Can be "object" or "hairs-like" or "auto" batch_size_seg=5, + batch_size_pre=5, batch_size_matting=1, + batch_size_refine=1, device='cuda' if torch.cuda.is_available() else 'cpu', - seg_mask_size=640, + seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net matting_mask_size=2048, + refine_mask_size=900, trimap_prob_threshold=231, trimap_dilation=30, trimap_erosion_iters=5, @@ -89,33 +117,65 @@ cat_wo_bg.save('2.png') ``` - +### Аналог ΠΌΠ΅Ρ‚ΠΎΠ΄Π° ΠΏΡ€Π΅Π΄Π²Π°Ρ€ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ `auto` ΠΈΠ· cli +``` python +from carvekit.api.autointerface import AutoInterface +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4 + +scene_classifier = SceneClassifier(device="cpu", batch_size=1) +object_classifier = SimplifiedYoloV4(device="cpu", batch_size=1) + +interface = AutoInterface(scene_classifier=scene_classifier, + object_classifier=object_classifier, + segmentation_batch_size=1, + postprocessing_batch_size=1, + postprocessing_image_size=2048, + refining_batch_size=1, + refining_image_size=900, + segmentation_device="cpu", + fp16=False, + postprocessing_device="cpu") +images_without_background = interface(['./tests/data/cat.jpg']) +cat_wo_bg = images_without_background[0] +cat_wo_bg.save('2.png') +``` ### Если Π²Ρ‹ Ρ…ΠΎΡ‚ΠΈΡ‚Π΅ провСсти Π΄Π΅Ρ‚Π°Π»ΡŒΠ½ΡƒΡŽ настройку ``` python import PIL.Image from carvekit.api.interface import Interface from carvekit.ml.wrap.fba_matting import FBAMatting +from carvekit.ml.wrap.scene_classifier import SceneClassifier +from carvekit.ml.wrap.cascadepsp import CascadePSP from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 -from carvekit.pipelines.postprocessing import MattingMethod -from carvekit.pipelines.preprocessing import PreprocessingStub +from carvekit.pipelines.postprocessing import CasMattingMethod +from carvekit.pipelines.preprocessing import AutoScene from carvekit.trimap.generator import 
TrimapGenerator # Check doc strings for more information seg_net = TracerUniversalB7(device='cpu', - batch_size=1) - + batch_size=1, fp16=False) +cascade_psp = CascadePSP(device='cpu', + batch_size=1, + input_tensor_size=900, + fp16=False, + processing_accelerate_image_size=2048, + global_step_only=False) fba = FBAMatting(device='cpu', input_tensor_size=2048, - batch_size=1) + batch_size=1, fp16=False) -trimap = TrimapGenerator() +trimap = TrimapGenerator(prob_threshold=231, kernel_size=30, erosion_iters=5) -preprocessing = PreprocessingStub() +scene_classifier = SceneClassifier(device='cpu', batch_size=5) +preprocessing = AutoScene(scene_classifier=scene_classifier) -postprocessing = MattingMethod(matting_module=fba, - trimap_generator=trimap, - device='cpu') +postprocessing = CasMattingMethod( + refining_module=cascade_psp, + matting_module=fba, + trimap_generator=trimap, + device='cpu') interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, @@ -123,8 +183,7 @@ interface = Interface(pre_pipe=preprocessing, image = PIL.Image.open('tests/data/cat.jpg') cat_wo_bg = interface([image])[0] -cat_wo_bg.save('2.png') - +cat_wo_bg.save('2.png') ``` @@ -140,24 +199,27 @@ Usage: carvekit [OPTIONS] Options: -i ./2.jpg ΠŸΡƒΡ‚ΡŒ Π΄ΠΎ Π²Ρ…ΠΎΠ΄Π½ΠΎΠ³ΠΎ Ρ„Π°ΠΉΠ»Π° ΠΈΠ»ΠΈ Π΄ΠΈΡ€Π΅ΠΊΡ‚ΠΎΡ€ΠΈΠΈ [обязатСлСн] -o ./2.png ΠŸΡƒΡ‚ΡŒ для сохранСния Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚Π° ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ - --pre none ΠœΠ΅Ρ‚ΠΎΠ΄ ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ - --post fba ΠœΠ΅Ρ‚ΠΎΠ΄ постобработки - --net u2net НСйронная ΡΠ΅Ρ‚ΡŒ для сСгмСнтации + --pre autoscene ΠœΠ΅Ρ‚ΠΎΠ΄ ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ + --post cascade_fba ΠœΠ΅Ρ‚ΠΎΠ΄ постобработки + --net tracer_b7 НСйронная ΡΠ΅Ρ‚ΡŒ для сСгмСнтации --recursive Π’ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ рСкурсивного поиска ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ Π² ΠΏΠ°ΠΏΠΊΠ΅ --batch_size 10 Π Π°Π·ΠΌΠ΅Ρ€ ΠΏΠ°ΠΊΠ΅Ρ‚Π° ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ, Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π½Ρ‹Ρ… Π² ΠžΠ—Π£ - + --batch_size_pre 5 Π Π°Π·ΠΌΠ΅Ρ€ ΠΏΠ°ΠΊΠ΅Ρ‚Π° для списка ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Π΅ Π±ΡƒΠ΄ΡƒΡ‚ 
ΠΎΠ±Ρ€Π°Π±Π°Ρ‚Ρ‹Π²Π°Ρ‚ΡŒΡΡ + ΠΌΠ΅Ρ‚ΠΎΠ΄ΠΎΠΌ ΠΏΡ€Π΅Π΄Π²Π°Ρ€ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ --batch_size_seg 5 Π Π°Π·ΠΌΠ΅Ρ€ ΠΏΠ°ΠΊΠ΅Ρ‚Π° ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ для ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ с ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ сСгмСнтации --batch_size_mat 1 Π Π°Π·ΠΌΠ΅Ρ€ ΠΏΠ°ΠΊΠ΅Ρ‚Π° ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ для ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ с ΠΏΠΎΠΌΠΎΡ‰ΡŒΡŽ матирования - --seg_mask_size 320 Π Π°Π·ΠΌΠ΅Ρ€ исходного изобраТСния для ΡΠ΅Π³ΠΌΠ΅Π½Ρ‚ΠΈΡ€ΡƒΡŽΡ‰Π΅ΠΉ + --batch_size_refine 1 Π Π°Π·ΠΌΠ΅Ρ€ ΠΏΠ°ΠΊΠ΅Ρ‚Π° для списка ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Π΅ Π±ΡƒΠ΄ΡƒΡ‚ ΠΎΠ±Ρ€Π°Π±Π°Ρ‚Ρ‹Π²Π°Ρ‚ΡŒΡΡ ΡƒΡ‚ΠΎΡ‡Π½ΡΡŽΡ‰Π΅ΠΉ ΡΠ΅Ρ‚ΡŒΡŽ + + --seg_mask_size 640 Π Π°Π·ΠΌΠ΅Ρ€ исходного изобраТСния для ΡΠ΅Π³ΠΌΠ΅Π½Ρ‚ΠΈΡ€ΡƒΡŽΡ‰Π΅ΠΉ Π½Π΅ΠΉΡ€ΠΎΠ½Π½ΠΎΠΉ сСти --matting_mask_size 2048 Π Π°Π·ΠΌΠ΅Ρ€ исходного изобраТСния для ΠΌΠ°Ρ‚ΠΈΡ€ΡƒΡŽΡ‰Π΅ΠΉ Π½Π΅ΠΉΡ€ΠΎΠ½Π½ΠΎΠΉ сСти - + --refine_mask_size 900 Π Π°Π·ΠΌΠ΅Ρ€ Π²Ρ…ΠΎΠ΄Π½ΠΎΠ³ΠΎ изобраТСния для ΡƒΡ‚ΠΎΡ‡Π½ΡΡŽΡ‰Π΅ΠΉ Π½Π΅ΠΉΡ€ΠΎΠ½Π½ΠΎΠΉ сСти. --trimap_dilation 30 Π Π°Π·ΠΌΠ΅Ρ€ радиуса смСщСния ΠΎΡ‚ маски ΠΎΠ±ΡŠΠ΅ΠΊΡ‚Π° Π² пиксСлях ΠΏΡ€ΠΈ Ρ„ΠΎΡ€ΠΌΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠΈ нСизвСстной области diff --git a/requirements_dev.txt b/requirements_dev.txt index eb6a8d8..7536b7f 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,2 +1,2 @@ pre-commit==2.20.0 - +pdoc3==0.10.0 diff --git a/tests/test_models_utils.py b/tests/test_models_utils.py index 0f81409..7dfa85f 100644 --- a/tests/test_models_utils.py +++ b/tests/test_models_utils.py @@ -16,6 +16,7 @@ checkpoints_dir, downloader, tracer_b7_pretrained, + scene_classifier_pretrained, ) from carvekit.utils.models_utils import fix_seed, suppress_warnings @@ -70,3 +71,4 @@ def test_check_for_exists(): assert deeplab_pretrained().exists() assert basnet_pretrained().exists() assert tracer_b7_pretrained().exists() + assert scene_classifier_pretrained().exists() diff --git a/tests/test_scene_classifier.py b/tests/test_scene_classifier.py new file mode 100644 index 0000000..c0827b9 --- /dev/null +++ b/tests/test_scene_classifier.py @@ -0,0 
+1,61 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. +License: Apache License 2.0 +""" + +import torch + +from carvekit.ml.wrap.scene_classifier import SceneClassifier + + +def test_init(): + SceneClassifier() + + +def test_preprocessing(scene_classifier_model, converted_pil_image, black_image_pil): + scene_classifier_model = scene_classifier_model(False) + assert ( + isinstance( + scene_classifier_model.data_preprocessing(converted_pil_image), + torch.FloatTensor, + ) + is True + ) + assert ( + isinstance( + scene_classifier_model.data_preprocessing(black_image_pil), + torch.FloatTensor, + ) + is True + ) + + +def test_inf( + scene_classifier_model, + converted_pil_image, + image_pil, + image_str, + image_path, + black_image_pil, +): + scene_classifier_model = scene_classifier_model(False) + calc_result = scene_classifier_model( + [ + converted_pil_image, + black_image_pil, + image_pil, + image_str, + image_path, + black_image_pil, + ] + ) + assert calc_result[0][0][0] == "soft" + assert calc_result[1][0][0] == "hard" + + +def test_seg_with_fp16( + scene_classifier_model, image_pil, image_str, image_path, black_image_pil +): + scene_classifier_model = scene_classifier_model(True) + scene_classifier_model([image_pil, image_str, image_path, black_image_pil]) diff --git a/tests/test_trimap.py b/tests/test_trimap.py index 47ba728..44a1354 100644 --- a/tests/test_trimap.py +++ b/tests/test_trimap.py @@ -1,6 +1,8 @@ """ Source url: https://github.com/OPHoperHPO/image-background-remove-tool + Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. + License: Apache License 2.0 """ import PIL.Image