WIP: Added docstrings in a format appropriate for pdoc3 to utils, trimap, pipelines, ml.wrap, ml.file, api #122

Open
wants to merge 14 commits into master
12 changes: 11 additions & 1 deletion .pre-commit-config.yaml
@@ -7,4 +7,14 @@ repos:
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
- id: flake8
- repo: local
hooks:
- id: pdoc
name: pdoc
description: 'pdoc3: Auto-generate API documentation for Python projects'
entry: pdoc --html --skip-errors --force -o docs/api carvekit
language: python
language_version: python3
require_serial: true
types: [python]
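
The hook above just shells out to pdoc3, so the same API documentation build can be reproduced outside pre-commit. A minimal sketch, assuming pdoc3 is installed and the command is run from the repository root:

``` python
import subprocess

# Rebuild the HTML API docs with the exact command used by the pre-commit hook.
subprocess.run(
    ["pdoc", "--html", "--skip-errors", "--force", "-o", "docs/api", "carvekit"],
    check=True,
)
```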
5 changes: 4 additions & 1 deletion Dockerfile.cpu
@@ -36,12 +36,15 @@ ENV CARVEKIT_PORT '5000'
ENV CARVEKIT_HOST '0.0.0.0'
ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7'
ENV CARVEKIT_PREPROCESSING_METHOD 'none'
ENV CARVEKIT_POSTPROCESSING_METHOD 'fba'
ENV CARVEKIT_POSTPROCESSING_METHOD 'cascade_fba'
ENV CARVEKIT_DEVICE 'cpu'
ENV CARVEKIT_BATCH_SIZE_PRE=5
ENV CARVEKIT_BATCH_SIZE_SEG '5'
ENV CARVEKIT_BATCH_SIZE_MATTING '1'
ENV CARVEKIT_BATCH_SIZE_REFINE '1'
ENV CARVEKIT_SEG_MASK_SIZE '640'
ENV CARVEKIT_MATTING_MASK_SIZE '2048'
ENV CARVEKIT_REFINE_MASK_SIZE '900'
ENV CARVEKIT_AUTH_ENABLE '1'
ENV CARVEKIT_FP16 '0'
ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231
5 changes: 4 additions & 1 deletion Dockerfile.cuda
@@ -36,12 +36,15 @@ ENV CARVEKIT_PORT '5000'
ENV CARVEKIT_HOST '0.0.0.0'
ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7'
ENV CARVEKIT_PREPROCESSING_METHOD 'none'
ENV CARVEKIT_POSTPROCESSING_METHOD 'fba'
ENV CARVEKIT_POSTPROCESSING_METHOD 'cascade_fba'
ENV CARVEKIT_DEVICE 'cuda'
ENV CARVEKIT_BATCH_SIZE_PRE=5
ENV CARVEKIT_BATCH_SIZE_SEG '5'
ENV CARVEKIT_BATCH_SIZE_MATTING '1'
ENV CARVEKIT_BATCH_SIZE_REFINE '1'
ENV CARVEKIT_SEG_MASK_SIZE '640'
ENV CARVEKIT_MATTING_MASK_SIZE '2048'
ENV CARVEKIT_REFINE_MASK_SIZE '900'
ENV CARVEKIT_AUTH_ENABLE '1'
ENV CARVEKIT_FP16 '0'
ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231
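The environment variables above (in both the CPU and CUDA images) configure the containerized FastAPI service. A minimal client sketch for a locally running container; the `/api/removebg` path, the `X-Api-Key` header, and the placeholder key are assumptions based on the remove.bg-compatible API, not values taken from this diff:

``` python
import requests

# Post an image to a locally running CarveKit HTTP API container.
# Endpoint path, header name, and key are assumptions; adjust to your deployment.
with open("cat.jpg", "rb") as image_file:
    response = requests.post(
        "http://localhost:5000/api/removebg",
        files={"image_file": image_file},
        data={"size": "auto"},
        headers={"X-Api-Key": "<your_api_key>"},
        timeout=120,
    )
response.raise_for_status()
with open("cat_wo_bg.png", "wb") as out:
    out.write(response.content)
```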
93 changes: 74 additions & 19 deletions README.md
@@ -26,13 +26,16 @@ Automated high-quality background removal framework for an image using neural ne

## 🎆 Features:
- High Quality
- Works offline
- Batch Processing
- NVIDIA CUDA and CPU processing
- FP16 inference: Fast inference with low memory usage
- Easy inference
- 100% remove.bg compatible FastAPI HTTP API
- Removes background from hairs
- Automatic best method selection for user's image
- Easy integration with your code
- Models hosted on [HuggingFace](https://huggingface.co/Carve)

## ⛱ Try yourself on [Google Colab](https://colab.research.google.com/github/OPHoperHPO/image-background-remove-tool/blob/master/docs/other/carvekit_try.ipynb)
## ⛓️ How does it work?
@@ -64,10 +67,17 @@ It can be briefly described as
## 🖼️ Image pre-processing and post-processing methods:
### 🔍 Preprocessing methods:
* `none` - No preprocessing methods used.
> They will be added in the future.
* [`autoscene`](https://huggingface.co/Carve/scene_classifier/) - Automatically detects the scene type using a classifier and applies the appropriate model. (default; a minimal programmatic sketch follows the notes below)
* `auto` - Performs in-depth image analysis and more accurately determines the best background removal method. Uses the object classifier and the scene classifier together.
> ### Notes:
> 1. `AutoScene` and `auto` may silently override the model and parameters specified by the user.
> So, if you want to use a specific model and keep all parameters fixed, disable the auto preprocessing methods first!
> 2. At the moment the `auto` method selects universal models for some specific domains, since there are not yet enough specialized models for so many types of scenes.
> Auto-selection will be improved as more models are added.
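
For reference, the `autoscene` method corresponds to the `AutoScene` preprocessing pipe wrapping a `SceneClassifier`, both of which also appear in the full pipeline example further below. A minimal sketch, with device and batch size chosen arbitrarily:

``` python
from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.pipelines.preprocessing import AutoScene

# The scene classifier picks the model best suited to the detected scene type.
scene_classifier = SceneClassifier(device="cpu", batch_size=1)
preprocessing = AutoScene(scene_classifier=scene_classifier)
```

The `auto` method additionally uses an object classifier; see the `AutoInterface` example below.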
### ✂ Post-processing methods:
* `none` - No post-processing methods used.
* `fba` (default) - This algorithm improves the borders of the image when removing the background from images with hair, etc. using FBA Matting neural network. This method gives the best result in combination with u2net without any preprocessing methods.
* `fba` - This algorithm improves the borders of the image when removing the background from images with hair, etc., using the FBA Matting neural network.
* `cascade_fba` (default) - This algorithm refines the segmentation mask using the CascadePSP neural network and then applies the FBA algorithm.

## 🏷 Setup for CPU processing:
1. `pip install carvekit --extra-index-url https://download.pytorch.org/whl/cpu`
@@ -84,12 +94,15 @@ import torch
from carvekit.api.high import HiInterface

# Check doc strings for more information
interface = HiInterface(object_type="hairs-like", # Can be "object" or "hairs-like".
interface = HiInterface(object_type="auto", # Can be "object" or "hairs-like" or "auto"
batch_size_seg=5,
batch_size_pre=5,
batch_size_matting=1,
batch_size_refine=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
refine_mask_size=900,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
@@ -100,42 +113,73 @@ cat_wo_bg.save('2.png')


```

### Analogue of the `auto` preprocessing method from the CLI
``` python
from carvekit.api.autointerface import AutoInterface
from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.ml.wrap.yolov4 import SimplifiedYoloV4

scene_classifier = SceneClassifier(device="cpu", batch_size=1)
object_classifier = SimplifiedYoloV4(device="cpu", batch_size=1)

interface = AutoInterface(scene_classifier=scene_classifier,
object_classifier=object_classifier,
segmentation_batch_size=1,
postprocessing_batch_size=1,
postprocessing_image_size=2048,
refining_batch_size=1,
refining_image_size=900,
segmentation_device="cpu",
fp16=False,
postprocessing_device="cpu")
images_without_background = interface(['./tests/data/cat.jpg'])
cat_wo_bg = images_without_background[0]
cat_wo_bg.save('2.png')
```
### If you want to control everything
``` python
import PIL.Image

from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.ml.wrap.cascadepsp import CascadePSP
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.pipelines.postprocessing import CasMattingMethod
from carvekit.pipelines.preprocessing import AutoScene
from carvekit.trimap.generator import TrimapGenerator

# Check doc strings for more information
seg_net = TracerUniversalB7(device='cpu',
batch_size=1)

batch_size=1, fp16=False)
cascade_psp = CascadePSP(device='cpu',
batch_size=1,
input_tensor_size=900,
fp16=False,
processing_accelerate_image_size=2048,
global_step_only=False)
fba = FBAMatting(device='cpu',
input_tensor_size=2048,
batch_size=1)
batch_size=1, fp16=False)

trimap = TrimapGenerator()
trimap = TrimapGenerator(prob_threshold=231, kernel_size=30, erosion_iters=5)

preprocessing = PreprocessingStub()
scene_classifier = SceneClassifier(device='cpu', batch_size=5)
preprocessing = AutoScene(scene_classifier=scene_classifier)

postprocessing = MattingMethod(matting_module=fba,
trimap_generator=trimap,
device='cpu')
postprocessing = CasMattingMethod(
refining_module=cascade_psp,
matting_module=fba,
trimap_generator=trimap,
device='cpu')

interface = Interface(pre_pipe=preprocessing,
post_pipe=postprocessing,
seg_pipe=seg_net)

image = PIL.Image.open('tests/data/cat.jpg')
cat_wo_bg = interface([image])[0]
cat_wo_bg.save('2.png')

cat_wo_bg.save('2.png')
```


@@ -151,24 +195,35 @@ Usage: carvekit [OPTIONS]
Options:
-i ./2.jpg Path to input file or dir [required]
-o ./2.png Path to output file or dir
--pre none Preprocessing method
--post fba Postprocessing method.
--pre autoscene Preprocessing method
--post cascade_fba Postprocessing method.
--net tracer_b7 Segmentation Network. Check README for more info.

--recursive Enables recursive search for images in a folder
--batch_size 10 Batch Size for list of images to be loaded to
RAM


--batch_size_pre 5 Batch size for list of images to be
processed by preprocessing method network

--batch_size_seg 5 Batch size for list of images to be processed
by segmentation network

--batch_size_mat 1 Batch size for list of images to be processed
by matting network

--batch_size_refine 1 Batch size for list of images to be
processed by refining network

--seg_mask_size 640 The size of the input image for the
segmentation neural network. Use 640 for Tracer B7 and 320 for U2Net

--matting_mask_size 2048 The size of the input image for the matting
neural network.

--refine_mask_size 900 The size of the input image for the refining
neural network.

--trimap_dilation 30 The size of the offset radius from the
object mask in pixels when forming an
unknown area
2 changes: 1 addition & 1 deletion carvekit/__init__.py
@@ -1 +1 @@
version = "4.1.0"
version = "4.5.0"
28 changes: 26 additions & 2 deletions carvekit/__main__.py
@@ -16,8 +16,8 @@
)
@click.option("-i", required=True, type=str, help="Path to input file or dir")
@click.option("-o", default="none", type=str, help="Path to output file or dir")
@click.option("--pre", default="none", type=str, help="Preprocessing method")
@click.option("--post", default="fba", type=str, help="Postprocessing method.")
@click.option("--pre", default="autoscene", type=str, help="Preprocessing method")
@click.option("--post", default="cascade_fba", type=str, help="Postprocessing method.")
@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
@click.option(
"--recursive",
@@ -31,6 +31,12 @@
type=int,
help="Batch Size for list of images to be loaded to RAM",
)
@click.option(
"--batch_size_pre",
default=5,
type=int,
help="Batch size for list of images to be processed by preprocessing method network",
)
@click.option(
"--batch_size_seg",
default=5,
@@ -43,6 +49,12 @@
type=int,
help="Batch size for list of images to be processed by matting " "network",
)
@click.option(
"--batch_size_refine",
default=1,
type=int,
help="Batch size for list of images to be processed by refining network",
)
@click.option(
"--seg_mask_size",
default=640,
@@ -55,6 +67,12 @@
type=int,
help="The size of the input image for the matting neural network.",
)
@click.option(
"--refine_mask_size",
default=900,
type=int,
help="The size of the input image for the refining neural network.",
)
@click.option(
"--trimap_dilation",
default=30,
@@ -89,10 +107,13 @@ def removebg(
net: str,
recursive: bool,
batch_size: int,
batch_size_pre: int,
batch_size_seg: int,
batch_size_mat: int,
batch_size_refine: int,
seg_mask_size: int,
matting_mask_size: int,
refine_mask_size: int,
device: str,
fp16: bool,
trimap_dilation: int,
@@ -121,12 +142,15 @@ def removebg(
device=device,
batch_size_seg=batch_size_seg,
batch_size_matting=batch_size_mat,
batch_size_refine=batch_size_refine,
seg_mask_size=seg_mask_size,
matting_mask_size=matting_mask_size,
refine_mask_size=refine_mask_size,
fp16=fp16,
trimap_dilation=trimap_dilation,
trimap_erosion=trimap_erosion,
trimap_prob_threshold=trimap_prob_threshold,
batch_size_pre=batch_size_pre,
)

interface = init_interface(interface_config)
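The new `--batch_size_pre`, `--batch_size_refine`, and `--refine_mask_size` options, together with the updated `--pre`/`--post` defaults, can be exercised end to end through click's test runner. A sketch with illustrative paths and values:

``` python
from click.testing import CliRunner

from carvekit.__main__ import removebg

# Invoke the carvekit CLI programmatically with the newly added pre/refine options.
runner = CliRunner()
result = runner.invoke(
    removebg,
    [
        "-i", "./tests/data/cat.jpg",
        "-o", "./cat_wo_bg.png",
        "--pre", "autoscene",
        "--post", "cascade_fba",
        "--batch_size_pre", "5",
        "--batch_size_refine", "1",
        "--refine_mask_size", "900",
    ],
)
print(result.exit_code, result.output)
```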