Large scale controlnet #260

Merged (14 commits) on Jul 6, 2023
Changes from 13 commits
11 changes: 9 additions & 2 deletions components/download_images/fondant_component.yaml
@@ -1,6 +1,6 @@
name: Download images
description: Component that downloads images based on URLs
image: ghcr.io/ml6team/download_images:dev
image: ghcr.io/ml6team/download_images:latest
Member commented:

Suggested change
image: ghcr.io/ml6team/download_images:latest
image: ghcr.io/ml6team/download_images:dev

The images on main should be fixed to dev, which corresponds to the latest main version. latest corresponds to the latest release.

Contributor (author) replied:

Right, forgot to revert this back!


consumes:
images:
@@ -23,21 +23,28 @@ args:
timeout:
description: Maximum time (in seconds) to wait when trying to download an image
type: int
default: 10
retries:
description: Number of times to retry downloading an image if it fails.
type: int
default: 0
image_size:
description: Size of the images after resizing.
type: int
default: 256
resize_mode:
description: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border".
type: str
default: 'border'
resize_only_if_bigger:
description: If True, resize only if image is bigger than image_size.
type: bool
default: 'False'
min_image_size:
description: Minimum size of the images.
type: int
default: 0
max_aspect_ratio:
description: Maximum aspect ratio of the images.
type: float
type: float
default: 'inf'
101 changes: 61 additions & 40 deletions components/download_images/src/main.py
@@ -10,10 +10,10 @@
import traceback
import urllib

import pandas as pd
import dask.dataframe as dd
from resizer import Resizer

from fondant.component import PandasTransformComponent
from fondant.component import DaskTransformComponent

logger = logging.getLogger(__name__)

@@ -29,7 +29,7 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives):
else None
)
if (ua_token is None or ua_token == user_agent_token) and any(
x in disallowed_header_directives for x in directives
x in disallowed_header_directives for x in directives
):
return True
except Exception as err:
@@ -53,9 +53,9 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives)
)
with urllib.request.urlopen(request, timeout=timeout) as r:
if disallowed_header_directives and is_disallowed(
r.headers,
user_agent_token,
disallowed_header_directives,
r.headers,
user_agent_token,
disallowed_header_directives,
):
return None
img_stream = io.BytesIO(r.read())
@@ -67,64 +67,85 @@ def download_image_with_retry(


def download_image_with_retry(
url,
*,
timeout,
retries,
resizer,
user_agent_token=None,
disallowed_header_directives=None,
url,
*,
timeout,
retries,
resizer,
user_agent_token=None,
disallowed_header_directives=None,
):
for _ in range(retries + 1):
img_stream = download_image(
url, timeout, user_agent_token, disallowed_header_directives,
)
if img_stream is not None:
# resize the image
return resizer(img_stream)
img_str, width, height = resizer(img_stream)
return img_str, width, height
return None, None, None


class DownloadImagesComponent(PandasTransformComponent):
class DownloadImagesComponent(DaskTransformComponent):
"""Component that downloads images based on URLs."""

def setup(
self,
*,
timeout: int = 10,
retries: int = 0,
image_size: int = 256,
resize_mode: str = "border",
resize_only_if_bigger: bool = False,
min_image_size: int = 0,
max_aspect_ratio: float = float("inf"),
):
def transform(
self,
dataframe: dd.DataFrame,
*,
timeout: int,
retries: int,
image_size: int,
resize_mode: str,
resize_only_if_bigger: bool,
min_image_size: int,
max_aspect_ratio: float,
) -> dd.DataFrame:
"""Function that downloads images from a list of URLs and executes filtering and resizing
Args:
dataframe: Dask dataframe
timeout: Maximum time (in seconds) to wait when trying to download an image.
retries: Number of times to retry downloading an image if it fails.
image_size: Size of the images after resizing.
resize_mode: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border".
resize_only_if_bigger: If True, resize only if image is bigger than image_size.
min_image_size: Minimum size of the images.
max_aspect_ratio: Maximum aspect ratio of the images.

Returns:
Dask dataframe
"""
logger.info("Instantiating resizer...")
self.resizer = Resizer(
resizer = Resizer(
image_size=image_size,
resize_mode=resize_mode,
resize_only_if_bigger=resize_only_if_bigger,
min_image_size=min_image_size,
max_aspect_ratio=max_aspect_ratio,
)
self.timeout = timeout
self.retries = retries

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe[[
("images", "data"),
("images", "width"),
("images", "height"),
]] = dataframe.apply(

# Remove duplicates from laion retrieval
dataframe = dataframe.drop_duplicates()

dataframe = dataframe.apply(
lambda example: download_image_with_retry(
url=example["images"]["url"],
timeout=self.timeout,
retries=self.retries,
resizer=self.resizer,
url=example.images_url,
timeout=timeout,
retries=retries,
resizer=resizer,
),
axis=1,
result_type="expand",
meta={0: bytes, 1: int, 2: int},
)
dataframe.columns = [
"images_data",
"images_width",
"images_height",
]

# Remove images that could not be fetched
dataframe = dataframe.dropna()

return dataframe

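For context on the refactor above (PandasTransformComponent to DaskTransformComponent), here is a minimal, self-contained sketch of how a row-wise Dask apply with an explicit meta behaves. The toy dataframe and the fake_download helper are illustrative stand-ins, not code from this PR.

import dask.dataframe as dd
import pandas as pd

def fake_download(url):
    # Stand-in for download_image_with_retry: (bytes, width, height) on success,
    # (None, None, None) on failure.
    if url is None:
        return None, None, None
    return b"<jpeg bytes>", 256, 256

pdf = pd.DataFrame({"images_url": ["http://example.com/a.jpg", None]})
ddf = dd.from_pandas(pdf, npartitions=1)

# meta declares the dtypes of the three expanded result columns up front,
# which Dask needs because the apply is evaluated lazily.
result = ddf.apply(
    lambda row: fake_download(row.images_url),
    axis=1,
    result_type="expand",
    meta={0: bytes, 1: int, 2: int},
)
result.columns = ["images_data", "images_width", "images_height"]
print(result.dropna().compute())  # failed downloads are dropped, as in the component
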
8 changes: 4 additions & 4 deletions components/download_images/src/resizer.py
@@ -174,20 +174,20 @@ def __call__(self, img_stream, blurring_bbox_list=None):
original_height, original_width = img.shape[:2]
# check if image is too small
if min(original_height, original_width) < self.min_image_size:
return None, None, None, None, None, "image too small"
return None, None, None
if original_height * original_width > self.max_image_area:
return None, None, None, None, None, "image area too large"
return None, None, None
# check if wrong aspect ratio
if (
max(original_height, original_width)
/ min(original_height, original_width)
> self.max_aspect_ratio
):
return None, None, None, None, None, "aspect ratio too large"
return None, None, None

# check if resizer was defined during init if needed
if blurring_bbox_list is not None and self.blurrer is None:
return None, None, None, None, None, "blurrer not defined"
return None, None, None

# Flag to check if blurring is still needed.
maybe_blur_still_needed = True
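
As a quick illustration of the filtering the resizer performs before resizing (and why rejected images now simply yield a (None, None, None) triple), here is a condensed standalone sketch; the passes_filters helper is hypothetical and only mirrors two of the checks above.

def passes_filters(height, width, *, min_image_size=0, max_aspect_ratio=float("inf")):
    # Reject images that are too small.
    if min(height, width) < min_image_size:
        return False
    # Reject images whose aspect ratio is too extreme.
    if max(height, width) / min(height, width) > max_aspect_ratio:
        return False
    return True

assert passes_filters(512, 512, min_image_size=256, max_aspect_ratio=2.0)
assert not passes_filters(100, 512, min_image_size=256, max_aspect_ratio=2.0)
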
@@ -25,3 +25,7 @@ args:
aesthetic_weight:
description: Weight of the aesthetic embedding when added to the query, between 0 and 1
type: float
url:
description: The url of the backend clip retrieval service, defaults to the public service
type: str
default: https://knn.laion.ai/knn-service
Contributor commented:

It doesn't seem like the main.py script has a default

Contributor (author) replied:

The defaults defined here translate internally to defaults defined in the argparser, since kfp always requires a value for an argument once it is specified and does not allow it to be empty.

parser.add_argument("--url", default="https://knn.laion.ai/knn-service")

The values defined in the argument parser generally take precedence over the default values defined in the main.py file, so adding them there can be a bit misleading (e.g. if the user attempts to change them, the default values won't be used).
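
A small sketch of the pattern described in this comment; the extra num_images argument is illustrative, while the --url default is the one quoted above.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--url", default="https://knn.laion.ai/knn-service")
parser.add_argument("--num_images", type=int, default=2)

args = parser.parse_args([])  # nothing supplied: the argparser default kicks in
print(args.url)               # https://knn.laion.ai/knn-service

args = parser.parse_args(["--url", "http://my-backend:1234/knn-service"])
print(args.url)               # an explicit value takes precedence over the default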

6 changes: 4 additions & 2 deletions components/prompt_based_laion_retrieval/src/main.py
@@ -21,6 +21,7 @@ def setup(
num_images: int,
aesthetic_score: int,
aesthetic_weight: float,
url: str,
) -> None:
"""

@@ -30,10 +31,11 @@ def setup(
between 0 and 9.
aesthetic_weight: weight of the aesthetic embedding to add to the query,
between 0 and 1.
url: The url of the backend clip retrieval service, defaults to the public clip url.
"""
self.client = ClipClient(
url="https://knn.laion.ai/knn-service",
indice_name="laion5B-L-14",
url=url,
indice_name="laion5B", #TODO:revert back to laion5b-L-14 after backend correction
num_images=num_images,
aesthetic_score=aesthetic_score,
aesthetic_weight=aesthetic_weight,
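
For reference, a hedged sketch of how the now-configurable url is consumed by clip-retrieval's ClipClient; the query call and result shape follow the standard clip-retrieval client API and are not part of this diff.

from clip_retrieval.clip_client import ClipClient

client = ClipClient(
    url="https://knn.laion.ai/knn-service",  # or a self-hosted backend
    indice_name="laion5B",                   # see the TODO above about laion5B-L-14
    num_images=2,
    aesthetic_score=9,
    aesthetic_weight=0.5,
)
results = client.query(text="comfortable modern living room")
# Each result is a dict with keys such as "url", "caption" and "similarity".
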
13 changes: 9 additions & 4 deletions components/segment_images/src/main.py
@@ -15,15 +15,15 @@
logger = logging.getLogger(__name__)


def convert_to_rgb(seg: np.array):
def convert_to_rgb(seg: np.array) -> bytes:
"""
Converts a 2D segmentation to a RGB one which makes it possible to visualize it.

Args:
seg: 2D segmentation map as a NumPy array.

Returns:
color_seg: 3D segmentation map contain RGB values for each pixel.
color_seg: the RGB segmentation map as a binary string
"""
color_seg = np.zeros(
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8,
@@ -32,9 +32,13 @@ def convert_to_rgb(seg: np.array):
for label, color in enumerate(palette):
color_seg[seg == label, :] = color

color_seg = color_seg.astype(np.uint8).tobytes()
color_seg = color_seg.astype(np.uint8)
image = Image.fromarray(color_seg).convert('RGB')

return color_seg
crop_bytes = io.BytesIO()
image.save(crop_bytes, format="JPEG")
Member commented:

Does this actually save the image to disk?

Contributor (author) replied:

No, this just saves it to crop_bytes, which is a BytesIO object (an in-memory buffer that stores the image in binary format).

Member replied:

Ok makes sense, thanks!


return crop_bytes.getvalue()


def process_image(image: bytes, *, processor: SegformerImageProcessor, device: str) -> torch.Tensor:
@@ -46,6 +50,7 @@ def process_image(image: bytes, *, processor: SegformerImageProcessor, device: s
processor: The processor object for transforming the image.
device: The device to move the transformed image to.
"""

def load(img: bytes) -> Image:
"""Load the bytestring as an image."""
bytes_ = io.BytesIO(img)
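
To back up the exchange above about Image.save and disk I/O, here is a tiny self-contained check (not part of the PR) showing that saving to a BytesIO buffer keeps everything in memory.

import io
import numpy as np
from PIL import Image

color_seg = np.zeros((4, 4, 3), dtype=np.uint8)  # dummy 4x4 RGB segmentation
image = Image.fromarray(color_seg).convert("RGB")

buffer = io.BytesIO()                 # in-memory buffer, nothing is written to disk
image.save(buffer, format="JPEG")     # JPEG-encode into the buffer
jpeg_bytes = buffer.getvalue()
print(type(jpeg_bytes), len(jpeg_bytes))  # <class 'bytes'> and a non-zero length
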
3 changes: 2 additions & 1 deletion components/write_to_hf_hub/src/main.py
@@ -8,6 +8,7 @@

# Define the schema for the struct using PyArrow
import huggingface_hub
from datasets.features.features import generate_from_arrow_type
from PIL import Image

from fondant.component import WriteComponent
@@ -71,7 +72,7 @@ def write(
if image_column_names and column_name in image_column_names:
schema_dict[column_name] = datasets.Image()
else:
schema_dict[column_name] = datasets.Value(str(field.type.value))
schema_dict[column_name] = generate_from_arrow_type(field.type.value)

schema = datasets.Features(schema_dict).arrow_schema
dataframe = dataframe[write_columns]
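
A short, hedged illustration of why generate_from_arrow_type is the safer mapping here: it converts arbitrary (including nested) PyArrow types to datasets features, whereas datasets.Value(str(...)) only covers scalar dtypes. The example types below are illustrative.

import pyarrow as pa
from datasets.features.features import generate_from_arrow_type

print(generate_from_arrow_type(pa.string()))           # a scalar Value feature
print(generate_from_arrow_type(pa.list_(pa.int64())))  # a Sequence of int64 values
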
@@ -16,9 +16,7 @@ consumes:
segmentations:
fields:
data:
type: array
items:
type: binary
type: binary

args:
hf_token:
11 changes: 7 additions & 4 deletions examples/pipelines/controlnet-interior-design/pipeline.py
@@ -20,12 +20,17 @@
)
laion_retrieval_op = ComponentOp.from_registry(
name="prompt_based_laion_retrieval",
arguments={"num_images": 2, "aesthetic_score": 9, "aesthetic_weight": 0.5},
arguments={
"num_images": 2,
"aesthetic_score": 9,
"aesthetic_weight": 0.5,
"url": None,
},
)
download_images_op = ComponentOp.from_registry(
name="download_images",
arguments={
"timeout": 10,
"timeout": 1,
"retries": 0,
"image_size": 512,
"resize_mode": "center_crop",
@@ -63,8 +68,6 @@
"hf_token": "hf_token",
"image_column_names": ["images_data"],
},
number_of_gpus=1,
node_pool_name="model-inference-pool",
)

pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)
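
For completeness, an assumed sketch of how the ops defined above are wired together; the add_op / dependencies calls follow the fondant Pipeline API of this era and are not shown in the visible diff.

pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)
pipeline.add_op(laion_retrieval_op)
pipeline.add_op(download_images_op, dependencies=laion_retrieval_op)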