Skip to content

Commit

Permalink
Make download_images component concurrent (#354)
Browse files Browse the repository at this point in the history
This PR makes the `download_images` component concurrent.

This is just a quick fix; ideally, we would rewrite the component to use an
async HTTP client like httpx. I will pick this up in a separate PR.
  • Loading branch information
RobbeSneyders authored Aug 17, 2023
1 parent 9f5d3a6 commit e580189
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 133 deletions.
15 changes: 11 additions & 4 deletions components/download_images/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM --platform=linux/amd64 python:3.8-slim
FROM --platform=linux/amd64 python:3.8-slim as base

# System dependencies
RUN apt-get update && \
Expand All @@ -15,9 +15,16 @@ ARG FONDANT_VERSION=main
RUN pip3 install fondant[aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component/src
WORKDIR /component
COPY src/ src/
ENV PYTHONPATH "${PYTHONPATH}:./src"

# Copy over src-files
COPY src/ .
FROM base as test
COPY test_requirements.txt .
RUN pip3 install --no-cache-dir -r test_requirements.txt
COPY tests/ tests/
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["python", "main.py"]
7 changes: 7 additions & 0 deletions components/download_images/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ If the component is unable to retrieve the image at a URL (for any reason), it w

See [`fondant_component.yaml`](fondant_component.yaml) for a more detailed description on all the input/output parameters.


### Testing

You can run the tests using Docker with BuildKit enabled. From this directory, run:
```
docker build . --target test
```
3 changes: 2 additions & 1 deletion components/download_images/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
albumentations==1.3.0
opencv-python-headless>=4.5.5.62,<5
opencv-python-headless>=4.5.5.62,<5
httpx==0.24.1
185 changes: 57 additions & 128 deletions components/download_images/src/main.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,25 @@
"""
This component downloads images based on URLs, and resizes them based on various settings like
minimum image size and aspect ratio.
Some functions here are directly taken from
https://github.com/rom1504/img2dataset/blob/main/img2dataset/downloader.py.
"""
import asyncio
import io
import logging
import traceback
import urllib
import typing as t

import dask.dataframe as dd
from fondant.component import DaskTransformComponent
from fondant.executor import DaskTransformExecutor
import dask
import httpx
import pandas as pd
from fondant.component import PandasTransformComponent
from fondant.executor import PandasTransformExecutor
from resizer import Resizer

logger = logging.getLogger(__name__)

dask.config.set(scheduler='processes')

def is_disallowed(headers, user_agent_token, disallowed_header_directives):
    """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage.

    Args:
        headers: HTTP response headers (must support ``get_all``, e.g.
            ``http.client.HTTPMessage``).
        user_agent_token: token identifying our crawler; a directive scoped to a
            different token is ignored.
        disallowed_header_directives: collection of directive strings (lowercase)
            that disallow usage, e.g. {"noai", "noindex"}.

    Returns:
        True if any X-Robots-Tag header disallows usage for us, False otherwise.
    """
    for values in headers.get_all("X-Robots-Tag", []):
        try:
            # A value is either "directive[,directive...]" (applies to all agents)
            # or "ua-token: directive[,directive...]" (applies to that agent only).
            uatoken_directives = values.split(":", 1)
            directives = [x.strip().lower() for x in uatoken_directives[-1].split(",")]
            ua_token = (
                uatoken_directives[0].lower() if len(uatoken_directives) == 2  # noqa: PLR2004
                else None
            )
            if (ua_token is None or ua_token == user_agent_token) and any(
                x in disallowed_header_directives for x in directives
            ):
                return True
        except Exception as err:
            # Log through the logging framework instead of printing to stdout,
            # so failures surface via the component's normal logging setup.
            logging.getLogger(__name__).warning(
                "Failed to parse X-Robots-Tag: %s: %s", values, err,
            )
    return False


def download_image(url, timeout, user_agent_token, disallowed_header_directives):
    """Fetch the image at *url* with urllib and return it as a BytesIO stream.

    Returns None when the request fails for any reason, or when the response
    headers carry an X-Robots-Tag directive that disallows usage.
    """
    stream = None
    user_agent_string = (
        "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
    )
    if user_agent_token:
        user_agent_string += f" (compatible; {user_agent_token}; " \
                             f"+https://github.com/rom1504/img2dataset)"
    try:
        request = urllib.request.Request(
            url, data=None, headers={"User-Agent": user_agent_string},
        )
        with urllib.request.urlopen(request, timeout=timeout) as response:
            blocked = disallowed_header_directives and is_disallowed(
                response.headers,
                user_agent_token,
                disallowed_header_directives,
            )
            if blocked:
                return None
            stream = io.BytesIO(response.read())
            return stream
    except Exception:
        # Best-effort download: any failure (network, robots parsing, read)
        # yields None. Release the buffer if it was already created.
        if stream is not None:
            stream.close()
        return None


def download_image_with_retry(
    url,
    *,
    timeout,
    retries,
    resizer,
    user_agent_token=None,
    disallowed_header_directives=None,
):
    """Download *url*, retrying up to *retries* extra times, then resize it.

    Returns a (image_bytes, width, height) tuple, or (None, None, None) when
    every attempt fails.
    """
    attempts = retries + 1
    for _attempt in range(attempts):
        stream = download_image(
            url, timeout, user_agent_token, disallowed_header_directives,
        )
        if stream is None:
            continue
        # Resize the successfully downloaded image and stop retrying.
        return resizer(stream)
    return None, None, None


def download_image_with_retry_partition(dataframe, timeout, retries, resizer):
    """Download and resize every image URL in a single Dask partition.

    Adds "data", "width" and "height" columns to the partition's dataframe.
    """
    # TODO make column name more flexible
    results = dataframe.images_url.apply(
        lambda url: download_image_with_retry(
            url=url, timeout=timeout, retries=retries, resizer=resizer,
        ),
    )

    # Split the (data, width, height) tuples into separate columns.
    new_columns = {
        "data": [item[0] for item in results],
        "width": [item[1] for item in results],
        "height": [item[2] for item in results],
    }
    return dataframe.assign(**new_columns)


class DownloadImagesComponent(DaskTransformComponent):
class DownloadImagesComponent(PandasTransformComponent):
"""Component that downloads images based on URLs."""

def __init__(self,
Expand Down Expand Up @@ -141,46 +56,60 @@ def __init__(self,
max_aspect_ratio=max_aspect_ratio,
)

def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame:

logger.info(f"Length of the dataframe: {len(dataframe)}")
logger.info("Downloading images...")

# drop width and height columns, as those are going to be
# added later on
dataframe = dataframe.drop(columns=['images_width', 'images_height'])

# create meta
# needs to be a dictionary with keys = column names, values = dtypes of columns
# for each column in the output
meta = dict(zip(dataframe.columns, dataframe.dtypes))
meta["data"] = bytes
meta["width"] = int
meta["height"] = int

dataframe = dataframe.map_partitions(
download_image_with_retry_partition,
timeout=self.timeout,
retries=self.retries,
resizer=self.resizer,
meta=meta,
async def download_image(self, url: str) -> t.Optional[bytes]:
    """Download the image at *url* and return its raw bytes.

    Args:
        url: the image URL to fetch.

    Returns:
        The response body as bytes, or None when the request fails or the
        server returns a non-success status code.
    """
    user_agent_string = (
        "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
    )
    user_agent_string += " (compatible; +https://github.com/ml6team/fondant)"

    # The transport retries connection-level failures; HTTP errors are not retried.
    transport = httpx.AsyncHTTPTransport(retries=self.retries)
    async with httpx.AsyncClient(transport=transport) as client:
        try:
            response = await client.get(url, timeout=self.timeout,
                                        headers={"User-Agent": user_agent_string})
            # Treat non-2xx responses as failures instead of storing an
            # HTML error page as image bytes.
            response.raise_for_status()
            image_stream = response.content
        except Exception as e:
            logger.warning(f"Skipping {url}: {e}")
            image_stream = None

    return image_stream

async def download_and_resize_image(self, id_: str, url: str) \
        -> t.Tuple[str, t.Optional[bytes], t.Optional[int], t.Optional[int]]:
    """Fetch one image and resize it.

    Returns an (id, image_bytes, width, height) tuple; the last three entries
    are None when the download failed.
    """
    raw_bytes = await self.download_image(url)
    if raw_bytes is None:
        return id_, None, None, None
    resized, width, height = self.resizer(io.BytesIO(raw_bytes))
    return id_, resized, width, height

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
    """Download and resize all images referenced in the ("images", "url") column.

    Downloads run concurrently via asyncio.gather. Rows whose image could not
    be fetched or resized are dropped from the output.

    Args:
        dataframe: input dataframe with an ("images", "url") column, indexed by id.

    Returns:
        A dataframe indexed by id with ("images", "data"/"width"/"height") columns.
    """
    logger.info(f"Downloading {len(dataframe)} images...")

    results: t.List[t.Tuple[str, bytes, int, int]] = []

    async def download_dataframe() -> None:
        # Fan out one task per row; gather preserves input order.
        images = await asyncio.gather(
            *[self.download_and_resize_image(id_, url)
              for id_, url in zip(dataframe.index, dataframe["images"]["url"])],
        )
        results.extend(images)

    asyncio.run(download_dataframe())

    columns = ["id", "data", "width", "height"]
    if results:
        results_df = pd.DataFrame(results, columns=columns)
    else:
        # Keep the expected schema even when the input dataframe is empty.
        results_df = pd.DataFrame(columns=columns)

    # Remove rows for images that could not be fetched (their data is None).
    results_df = results_df.dropna()
    results_df = results_df.set_index("id", drop=True)
    # Conform to the component spec: nest columns under the "images" subset.
    results_df.columns = pd.MultiIndex.from_product([["images"], results_df.columns])

    return results_df


if __name__ == "__main__":
    # Entry point when run inside the component container.
    executor = PandasTransformExecutor.from_args()
    executor.execute(DownloadImagesComponent)
2 changes: 2 additions & 0 deletions components/download_images/test_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pytest==7.4.0
respx==0.20.2
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 70 additions & 0 deletions components/download_images/tests/test_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import io
import os

import pandas as pd
from httpx import Response

from src.main import DownloadImagesComponent


def test_transform(respx_mock):
    """Test the component transform method."""
    # Define input data and arguments
    # These can be parametrized in the future
    ids = [
        "a",
        "b",
        "c",
    ]
    urls = [
        "http://host/path.png",
        "https://host/path.png",
        "https://host/path.jpg",
    ]
    image_size = 256

    # Mock httpx to prevent network calls and return test images.
    # Sort the directory listing so the url <-> image pairing is deterministic,
    # and close each file handle instead of leaking it.
    image_dir = "tests/images"
    images = []
    for image_name in sorted(os.listdir(image_dir)):
        with open(os.path.join(image_dir, image_name), "rb") as image_file:
            images.append(image_file.read())
    for url, image in zip(urls, images):
        respx_mock.get(url).mock(return_value=Response(200, content=image))

    component = DownloadImagesComponent(
        timeout=10,
        retries=0,
        image_size=image_size,
        resize_mode="border",
        resize_only_if_bigger=False,
        min_image_size=0,
        max_aspect_ratio=float("inf"),
    )

    input_dataframe = pd.DataFrame(
        {
            ("images", "url"): urls,
        },
        index=pd.Index(ids, name="id"),
    )

    # Use the resizer from the component to generate the expected output images
    # But use the image_size argument to validate actual resizing
    resized_images = [component.resizer(io.BytesIO(image))[0] for image in images]
    expected_dataframe = pd.DataFrame(
        {
            ("images", "data"): resized_images,
            ("images", "width"): [image_size] * len(ids),
            ("images", "height"): [image_size] * len(ids),
        },
        index=pd.Index(ids, name="id"),
    )

    output_dataframe = component.transform(input_dataframe)

    pd.testing.assert_frame_equal(
        left=expected_dataframe,
        right=output_dataframe,
        check_dtype=False,
    )

0 comments on commit e580189

Please sign in to comment.