From 7f271246115568b487a6050f193003a38ddfe83f Mon Sep 17 00:00:00 2001 From: Philippe Moussalli Date: Thu, 6 Jul 2023 16:33:40 +0200 Subject: [PATCH] Large scale controlnet (#260) PR for running the controlnet pipeline end-to-end on KFP. Some observations when doing the pipeline testing: - Tested with @ChristiaensBert VM and it runs really nice and much faster than the public clip service. - I could not test everything end to end locally since the GPU component are difficult to run locally -> switched to KFP to leverage the GPU VMs - I had to rebuild images using the build and tag images in the `scripts` folder. I think we still need to modify the script to enable only building specified components since it currently default to all components in the `components` directory which might take some time to build - The local runner does not seem to do the subset checking yet and we still need to expand the CLI to be able to use the kfp runner (currently not supported). Although the CLI is really nice overall :) - Pipeline runs fine and writes the dataset to the hub but fails at the end since it expects an output manifest. This can be resolved with this [ticket](https://github.com/ml6team/fondant/pull/221). We should prioritize this. Notes: - Changed the segmentation to output a segmentation image instead of a segmentation array since that's the output expected for controlnet training Things to do: - Estimate how much the job would cost --- .../download_images/fondant_component.yaml | 9 +- components/download_images/src/main.py | 101 +++++++++++------- components/download_images/src/resizer.py | 8 +- .../fondant_component.yaml | 4 + .../prompt_based_laion_retrieval/src/main.py | 6 +- components/segment_images/src/main.py | 13 ++- components/write_to_hf_hub/src/main.py | 3 +- .../fondant_component.yaml | 4 +- .../controlnet-interior-design/pipeline.py | 11 +- 9 files changed, 100 insertions(+), 59 deletions(-) diff --git a/components/download_images/fondant_component.yaml b/components/download_images/fondant_component.yaml index 6f4262e29..1efaa48d4 100644 --- a/components/download_images/fondant_component.yaml +++ b/components/download_images/fondant_component.yaml @@ -23,21 +23,28 @@ args: timeout: description: Maximum time (in seconds) to wait when trying to download an image type: int + default: 10 retries: description: Number of times to retry downloading an image if it fails. type: int + default: 0 image_size: description: Size of the images after resizing. type: int + default: 256 resize_mode: description: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border". type: str + default: 'border' resize_only_if_bigger: description: If True, resize only if image is bigger than image_size. type: bool + default: 'False' min_image_size: description: Minimum size of the images. type: int + default: 0 max_aspect_ratio: description: Maximum aspect ratio of the images. - type: float \ No newline at end of file + type: float + default: 'inf' \ No newline at end of file diff --git a/components/download_images/src/main.py b/components/download_images/src/main.py index 9b222e3b0..017001e0d 100644 --- a/components/download_images/src/main.py +++ b/components/download_images/src/main.py @@ -10,10 +10,10 @@ import traceback import urllib -import pandas as pd +import dask.dataframe as dd from resizer import Resizer -from fondant.component import PandasTransformComponent +from fondant.component import DaskTransformComponent logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives): else None ) if (ua_token is None or ua_token == user_agent_token) and any( - x in disallowed_header_directives for x in directives + x in disallowed_header_directives for x in directives ): return True except Exception as err: @@ -53,9 +53,9 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives) ) with urllib.request.urlopen(request, timeout=timeout) as r: if disallowed_header_directives and is_disallowed( - r.headers, - user_agent_token, - disallowed_header_directives, + r.headers, + user_agent_token, + disallowed_header_directives, ): return None img_stream = io.BytesIO(r.read()) @@ -67,13 +67,13 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives) def download_image_with_retry( - url, - *, - timeout, - retries, - resizer, - user_agent_token=None, - disallowed_header_directives=None, + url, + *, + timeout, + retries, + resizer, + user_agent_token=None, + disallowed_header_directives=None, ): for _ in range(retries + 1): img_stream = download_image( @@ -81,50 +81,71 @@ def download_image_with_retry( ) if img_stream is not None: # resize the image - return resizer(img_stream) + img_str, width, height = resizer(img_stream) + return img_str, width, height return None, None, None -class DownloadImagesComponent(PandasTransformComponent): +class DownloadImagesComponent(DaskTransformComponent): """Component that downloads images based on URLs.""" - def setup( - self, - *, - timeout: int = 10, - retries: int = 0, - image_size: int = 256, - resize_mode: str = "border", - resize_only_if_bigger: bool = False, - min_image_size: int = 0, - max_aspect_ratio: float = float("inf"), - ): + def transform( + self, + dataframe: dd.DataFrame, + *, + timeout: int, + retries: int, + image_size: int, + resize_mode: str, + resize_only_if_bigger: bool, + min_image_size: int, + max_aspect_ratio: float, + ) -> dd.DataFrame: + """Function that downloads images from a list of URLs and executes filtering and resizing + Args: + dataframe: Dask dataframe + timeout: Maximum time (in seconds) to wait when trying to download an image. + retries: Number of times to retry downloading an image if it fails. + image_size: Size of the images after resizing. + resize_mode: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border". + resize_only_if_bigger: If True, resize only if image is bigger than image_size. + min_image_size: Minimum size of the images. + max_aspect_ratio: Maximum aspect ratio of the images. + + Returns: + Dask dataframe + """ logger.info("Instantiating resizer...") - self.resizer = Resizer( + resizer = Resizer( image_size=image_size, resize_mode=resize_mode, resize_only_if_bigger=resize_only_if_bigger, min_image_size=min_image_size, max_aspect_ratio=max_aspect_ratio, ) - self.timeout = timeout - self.retries = retries - - def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: - dataframe[[ - ("images", "data"), - ("images", "width"), - ("images", "height"), - ]] = dataframe.apply( + + # Remove duplicates from laion retrieval + dataframe = dataframe.drop_duplicates() + + dataframe = dataframe.apply( lambda example: download_image_with_retry( - url=example["images"]["url"], - timeout=self.timeout, - retries=self.retries, - resizer=self.resizer, + url=example.images_url, + timeout=timeout, + retries=retries, + resizer=resizer, ), axis=1, result_type="expand", + meta={0: bytes, 1: int, 2: int}, ) + dataframe.columns = [ + "images_data", + "images_width", + "images_height", + ] + + # Remove images that could not be fetched + dataframe = dataframe.dropna() return dataframe diff --git a/components/download_images/src/resizer.py b/components/download_images/src/resizer.py index 386a71d3f..f545a0bf1 100644 --- a/components/download_images/src/resizer.py +++ b/components/download_images/src/resizer.py @@ -174,20 +174,20 @@ def __call__(self, img_stream, blurring_bbox_list=None): original_height, original_width = img.shape[:2] # check if image is too small if min(original_height, original_width) < self.min_image_size: - return None, None, None, None, None, "image too small" + return None, None, None if original_height * original_width > self.max_image_area: - return None, None, None, None, None, "image area too large" + return None, None, None # check if wrong aspect ratio if ( max(original_height, original_width) / min(original_height, original_width) > self.max_aspect_ratio ): - return None, None, None, None, None, "aspect ratio too large" + return None, None, None # check if resizer was defined during init if needed if blurring_bbox_list is not None and self.blurrer is None: - return None, None, None, None, None, "blurrer not defined" + return None, None, None # Flag to check if blurring is still needed. maybe_blur_still_needed = True diff --git a/components/prompt_based_laion_retrieval/fondant_component.yaml b/components/prompt_based_laion_retrieval/fondant_component.yaml index 09cdb630d..5fa3bf331 100644 --- a/components/prompt_based_laion_retrieval/fondant_component.yaml +++ b/components/prompt_based_laion_retrieval/fondant_component.yaml @@ -25,3 +25,7 @@ args: aesthetic_weight: description: Weight of the aesthetic embedding when added to the query, between 0 and 1 type: float + url: + description: The url of the backend clip retrieval service, defaults to the public service + type: str + default: https://knn.laion.ai/knn-service \ No newline at end of file diff --git a/components/prompt_based_laion_retrieval/src/main.py b/components/prompt_based_laion_retrieval/src/main.py index 5109e94e5..6dbc39a57 100644 --- a/components/prompt_based_laion_retrieval/src/main.py +++ b/components/prompt_based_laion_retrieval/src/main.py @@ -21,6 +21,7 @@ def setup( num_images: int, aesthetic_score: int, aesthetic_weight: float, + url: str, ) -> None: """ @@ -30,10 +31,11 @@ def setup( between 0 and 9. aesthetic_weight: weight of the aesthetic embedding to add to the query, between 0 and 1. + url: The url of the backend clip retrieval service, defaults to the public clip url. """ self.client = ClipClient( - url="https://knn.laion.ai/knn-service", - indice_name="laion5B-L-14", + url=url, + indice_name="laion5B", #TODO:revert back to laion5b-L-14 after backend correction num_images=num_images, aesthetic_score=aesthetic_score, aesthetic_weight=aesthetic_weight, diff --git a/components/segment_images/src/main.py b/components/segment_images/src/main.py index b666127b3..89e9193ef 100644 --- a/components/segment_images/src/main.py +++ b/components/segment_images/src/main.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def convert_to_rgb(seg: np.array): +def convert_to_rgb(seg: np.array) -> bytes: """ Converts a 2D segmentation to a RGB one which makes it possible to visualize it. @@ -23,7 +23,7 @@ def convert_to_rgb(seg: np.array): seg: 2D segmentation map as a NumPy array. Returns: - color_seg: 3D segmentation map contain RGB values for each pixel. + color_seg: the RGB segmentation map as a binary string """ color_seg = np.zeros( (seg.shape[0], seg.shape[1], 3), dtype=np.uint8, @@ -32,9 +32,13 @@ def convert_to_rgb(seg: np.array): for label, color in enumerate(palette): color_seg[seg == label, :] = color - color_seg = color_seg.astype(np.uint8).tobytes() + color_seg = color_seg.astype(np.uint8) + image = Image.fromarray(color_seg).convert('RGB') - return color_seg + crop_bytes = io.BytesIO() + image.save(crop_bytes, format="JPEG") + + return crop_bytes.getvalue() def process_image(image: bytes, *, processor: SegformerImageProcessor, device: str) -> torch.Tensor: @@ -46,6 +50,7 @@ def process_image(image: bytes, *, processor: SegformerImageProcessor, device: s processor: The processor object for transforming the image. device: The device to move the transformed image to. """ + def load(img: bytes) -> Image: """Load the bytestring as an image.""" bytes_ = io.BytesIO(img) diff --git a/components/write_to_hf_hub/src/main.py b/components/write_to_hf_hub/src/main.py index a81bcb5c9..c3022b234 100644 --- a/components/write_to_hf_hub/src/main.py +++ b/components/write_to_hf_hub/src/main.py @@ -8,6 +8,7 @@ # Define the schema for the struct using PyArrow import huggingface_hub +from datasets.features.features import generate_from_arrow_type from PIL import Image from fondant.component import WriteComponent @@ -71,7 +72,7 @@ def write( if image_column_names and column_name in image_column_names: schema_dict[column_name] = datasets.Image() else: - schema_dict[column_name] = datasets.Value(str(field.type.value)) + schema_dict[column_name] = generate_from_arrow_type(field.type.value) schema = datasets.Features(schema_dict).arrow_schema dataframe = dataframe[write_columns] diff --git a/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml b/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml index efb253159..4915810f0 100644 --- a/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml +++ b/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml @@ -16,9 +16,7 @@ consumes: segmentations: fields: data: - type: array - items: - type: binary + type: binary args: hf_token: diff --git a/examples/pipelines/controlnet-interior-design/pipeline.py b/examples/pipelines/controlnet-interior-design/pipeline.py index 33b4d9054..ecb1c2e06 100644 --- a/examples/pipelines/controlnet-interior-design/pipeline.py +++ b/examples/pipelines/controlnet-interior-design/pipeline.py @@ -20,12 +20,17 @@ ) laion_retrieval_op = ComponentOp.from_registry( name="prompt_based_laion_retrieval", - arguments={"num_images": 2, "aesthetic_score": 9, "aesthetic_weight": 0.5}, + arguments={ + "num_images": 2, + "aesthetic_score": 9, + "aesthetic_weight": 0.5, + "url": None, + }, ) download_images_op = ComponentOp.from_registry( name="download_images", arguments={ - "timeout": 10, + "timeout": 1, "retries": 0, "image_size": 512, "resize_mode": "center_crop", @@ -63,8 +68,6 @@ "hf_token": "hf_token", "image_column_names": ["images_data"], }, - number_of_gpus=1, - node_pool_name="model-inference-pool", ) pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)