Skip to content

Commit

Permalink
Large scale controlnet (#260)
Browse files Browse the repository at this point in the history
PR for running the controlnet pipeline end-to-end on KFP. 

Some observations when doing the pipeline testing: 

- Tested with @ChristiaensBert VM and it runs really nice and much
faster than the public clip service.
- I could not test everything end to end locally since the GPU component
are difficult to run locally -> switched to KFP to leverage the GPU VMs
- I had to rebuild images using the build and tag images in the
`scripts` folder. I think we still need to modify the script to enable
only building specified components since it currently default to all
components in the `components` directory which might take some time to
build
- The local runner does not seem to do the subset checking yet and we
still need to expand the CLI to be able to use the kfp runner (currently
not supported). Although the CLI is really nice overall :)
- Pipeline runs fine and writes the dataset to the hub but fails at the
end since it expects an output manifest. This can be resolved with this
[ticket](#221). We should
prioritize this.

Notes:
- Changed the segmentation to output a segmentation image instead of a
segmentation array since that's the output expected for controlnet
training

Things to do: 
- Estimate how much the job would cost
  • Loading branch information
PhilippeMoussalli authored Jul 6, 2023
1 parent 36db085 commit 699d68f
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 59 deletions.
9 changes: 8 additions & 1 deletion components/download_images/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,28 @@ args:
timeout:
description: Maximum time (in seconds) to wait when trying to download an image
type: int
default: 10
retries:
description: Number of times to retry downloading an image if it fails.
type: int
default: 0
image_size:
description: Size of the images after resizing.
type: int
default: 256
resize_mode:
description: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border".
type: str
default: 'border'
resize_only_if_bigger:
description: If True, resize only if image is bigger than image_size.
type: bool
default: 'False'
min_image_size:
description: Minimum size of the images.
type: int
default: 0
max_aspect_ratio:
description: Maximum aspect ratio of the images.
type: float
type: float
default: 'inf'
101 changes: 61 additions & 40 deletions components/download_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import traceback
import urllib

import pandas as pd
import dask.dataframe as dd
from resizer import Resizer

from fondant.component import PandasTransformComponent
from fondant.component import DaskTransformComponent

logger = logging.getLogger(__name__)

Expand All @@ -29,7 +29,7 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives):
else None
)
if (ua_token is None or ua_token == user_agent_token) and any(
x in disallowed_header_directives for x in directives
x in disallowed_header_directives for x in directives
):
return True
except Exception as err:
Expand All @@ -53,9 +53,9 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives)
)
with urllib.request.urlopen(request, timeout=timeout) as r:
if disallowed_header_directives and is_disallowed(
r.headers,
user_agent_token,
disallowed_header_directives,
r.headers,
user_agent_token,
disallowed_header_directives,
):
return None
img_stream = io.BytesIO(r.read())
Expand All @@ -67,64 +67,85 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives)


def download_image_with_retry(
url,
*,
timeout,
retries,
resizer,
user_agent_token=None,
disallowed_header_directives=None,
url,
*,
timeout,
retries,
resizer,
user_agent_token=None,
disallowed_header_directives=None,
):
for _ in range(retries + 1):
img_stream = download_image(
url, timeout, user_agent_token, disallowed_header_directives,
)
if img_stream is not None:
# resize the image
return resizer(img_stream)
img_str, width, height = resizer(img_stream)
return img_str, width, height
return None, None, None


class DownloadImagesComponent(PandasTransformComponent):
class DownloadImagesComponent(DaskTransformComponent):
"""Component that downloads images based on URLs."""

def setup(
self,
*,
timeout: int = 10,
retries: int = 0,
image_size: int = 256,
resize_mode: str = "border",
resize_only_if_bigger: bool = False,
min_image_size: int = 0,
max_aspect_ratio: float = float("inf"),
):
def transform(
self,
dataframe: dd.DataFrame,
*,
timeout: int,
retries: int,
image_size: int,
resize_mode: str,
resize_only_if_bigger: bool,
min_image_size: int,
max_aspect_ratio: float,
) -> dd.DataFrame:
"""Function that downloads images from a list of URLs and executes filtering and resizing
Args:
dataframe: Dask dataframe
timeout: Maximum time (in seconds) to wait when trying to download an image.
retries: Number of times to retry downloading an image if it fails.
image_size: Size of the images after resizing.
resize_mode: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border".
resize_only_if_bigger: If True, resize only if image is bigger than image_size.
min_image_size: Minimum size of the images.
max_aspect_ratio: Maximum aspect ratio of the images.
Returns:
Dask dataframe
"""
logger.info("Instantiating resizer...")
self.resizer = Resizer(
resizer = Resizer(
image_size=image_size,
resize_mode=resize_mode,
resize_only_if_bigger=resize_only_if_bigger,
min_image_size=min_image_size,
max_aspect_ratio=max_aspect_ratio,
)
self.timeout = timeout
self.retries = retries

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe[[
("images", "data"),
("images", "width"),
("images", "height"),
]] = dataframe.apply(

# Remove duplicates from laion retrieval
dataframe = dataframe.drop_duplicates()

dataframe = dataframe.apply(
lambda example: download_image_with_retry(
url=example["images"]["url"],
timeout=self.timeout,
retries=self.retries,
resizer=self.resizer,
url=example.images_url,
timeout=timeout,
retries=retries,
resizer=resizer,
),
axis=1,
result_type="expand",
meta={0: bytes, 1: int, 2: int},
)
dataframe.columns = [
"images_data",
"images_width",
"images_height",
]

# Remove images that could not be fetched
dataframe = dataframe.dropna()

return dataframe

Expand Down
8 changes: 4 additions & 4 deletions components/download_images/src/resizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,20 @@ def __call__(self, img_stream, blurring_bbox_list=None):
original_height, original_width = img.shape[:2]
# check if image is too small
if min(original_height, original_width) < self.min_image_size:
return None, None, None, None, None, "image too small"
return None, None, None
if original_height * original_width > self.max_image_area:
return None, None, None, None, None, "image area too large"
return None, None, None
# check if wrong aspect ratio
if (
max(original_height, original_width)
/ min(original_height, original_width)
> self.max_aspect_ratio
):
return None, None, None, None, None, "aspect ratio too large"
return None, None, None

# check if resizer was defined during init if needed
if blurring_bbox_list is not None and self.blurrer is None:
return None, None, None, None, None, "blurrer not defined"
return None, None, None

# Flag to check if blurring is still needed.
maybe_blur_still_needed = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ args:
aesthetic_weight:
description: Weight of the aesthetic embedding when added to the query, between 0 and 1
type: float
url:
description: The url of the backend clip retrieval service, defaults to the public service
type: str
default: https://knn.laion.ai/knn-service
6 changes: 4 additions & 2 deletions components/prompt_based_laion_retrieval/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def setup(
num_images: int,
aesthetic_score: int,
aesthetic_weight: float,
url: str,
) -> None:
"""
Expand All @@ -30,10 +31,11 @@ def setup(
between 0 and 9.
aesthetic_weight: weight of the aesthetic embedding to add to the query,
between 0 and 1.
url: The url of the backend clip retrieval service, defaults to the public clip url.
"""
self.client = ClipClient(
url="https://knn.laion.ai/knn-service",
indice_name="laion5B-L-14",
url=url,
indice_name="laion5B", #TODO:revert back to laion5b-L-14 after backend correction
num_images=num_images,
aesthetic_score=aesthetic_score,
aesthetic_weight=aesthetic_weight,
Expand Down
13 changes: 9 additions & 4 deletions components/segment_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
logger = logging.getLogger(__name__)


def convert_to_rgb(seg: np.array):
def convert_to_rgb(seg: np.array) -> bytes:
"""
Converts a 2D segmentation to a RGB one which makes it possible to visualize it.
Args:
seg: 2D segmentation map as a NumPy array.
Returns:
color_seg: 3D segmentation map contain RGB values for each pixel.
color_seg: the RGB segmentation map as a binary string
"""
color_seg = np.zeros(
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8,
Expand All @@ -32,9 +32,13 @@ def convert_to_rgb(seg: np.array):
for label, color in enumerate(palette):
color_seg[seg == label, :] = color

color_seg = color_seg.astype(np.uint8).tobytes()
color_seg = color_seg.astype(np.uint8)
image = Image.fromarray(color_seg).convert('RGB')

return color_seg
crop_bytes = io.BytesIO()
image.save(crop_bytes, format="JPEG")

return crop_bytes.getvalue()


def process_image(image: bytes, *, processor: SegformerImageProcessor, device: str) -> torch.Tensor:
Expand All @@ -46,6 +50,7 @@ def process_image(image: bytes, *, processor: SegformerImageProcessor, device: s
processor: The processor object for transforming the image.
device: The device to move the transformed image to.
"""

def load(img: bytes) -> Image:
"""Load the bytestring as an image."""
bytes_ = io.BytesIO(img)
Expand Down
3 changes: 2 additions & 1 deletion components/write_to_hf_hub/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

# Define the schema for the struct using PyArrow
import huggingface_hub
from datasets.features.features import generate_from_arrow_type
from PIL import Image

from fondant.component import WriteComponent
Expand Down Expand Up @@ -71,7 +72,7 @@ def write(
if image_column_names and column_name in image_column_names:
schema_dict[column_name] = datasets.Image()
else:
schema_dict[column_name] = datasets.Value(str(field.type.value))
schema_dict[column_name] = generate_from_arrow_type(field.type.value)

schema = datasets.Features(schema_dict).arrow_schema
dataframe = dataframe[write_columns]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ consumes:
segmentations:
fields:
data:
type: array
items:
type: binary
type: binary

args:
hf_token:
Expand Down
11 changes: 7 additions & 4 deletions examples/pipelines/controlnet-interior-design/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
)
laion_retrieval_op = ComponentOp.from_registry(
name="prompt_based_laion_retrieval",
arguments={"num_images": 2, "aesthetic_score": 9, "aesthetic_weight": 0.5},
arguments={
"num_images": 2,
"aesthetic_score": 9,
"aesthetic_weight": 0.5,
"url": None,
},
)
download_images_op = ComponentOp.from_registry(
name="download_images",
arguments={
"timeout": 10,
"timeout": 1,
"retries": 0,
"image_size": 512,
"resize_mode": "center_crop",
Expand Down Expand Up @@ -63,8 +68,6 @@
"hf_token": "hf_token",
"image_column_names": ["images_data"],
},
number_of_gpus=1,
node_pool_name="model-inference-pool",
)

pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)
Expand Down

0 comments on commit 699d68f

Please sign in to comment.