Warc support #48

Open · wants to merge 5 commits into main
161 changes: 128 additions & 33 deletions cc2dataset/main.py
@@ -2,6 +2,8 @@


from fastwarc.warc import ArchiveIterator, WarcRecordType
from resiliparse.parse.html import HTMLTree
from resiliparse.extract.html2text import extract_plain_text
import simdjson
import fsspec
from timeit import default_timer as timer
@@ -17,14 +19,15 @@
import time
from .spark_session_builder import build_spark_session
from io import BytesIO
from urllib.parse import urljoin
import fire


def valid_video_link(link):
valid_http = link.get("url", "").startswith("http")
valid_video = any(
link.get("url", "").endswith(ext) for ext in [".avi", ".mp4", ".mkv", ".webm", ".mov", ".mpg", ".mpeg", ".m4v"]
)
return valid_http and valid_video
return valid_video


def extract_video_from_links(links):
@@ -54,8 +57,6 @@ def extract_video_from_links(links):


def valid_text_link(link):
if not link.get("url", "").startswith("http"):
return False
splits = link.get("url", "").split(".")
if len(splits) < 2:
return False
@@ -70,9 +71,8 @@ def extract_text_from_links(links):


def valid_audio_link(link):
valid_http = link.get("url", "").startswith("http")
valid_audio = any(link.get("url", "").endswith(ext) for ext in [".ogg", ".wav", ".mp3", ".flac", ".m4a"])
return valid_http and valid_audio
return valid_audio


def extract_audio_from_links(links):
Expand All @@ -84,8 +84,7 @@ def extract_audio_from_links(links):
def valid_image_link(link):
valid_path = link.get("path", "") == "IMG@/src"
valid_alt = len(link.get("alt", "")) > 0
valid_http = link.get("url", "").startswith("http")
return valid_path and valid_http and valid_alt
return valid_path and valid_alt


def extract_image_from_links(links):
Expand All @@ -94,21 +93,73 @@ def extract_image_from_links(links):
return filtered_links


def valid_image_only_link(link):
valid_path = link.get("path", "") == "IMG@/src"
return valid_path


def extract_image_only_from_links(links):
"""Extract images from links even when no caption is present"""
filtered_links = [{"url": link["url"], "alt": link.get("alt", "")} for link in links if valid_image_only_link(link)]
return filtered_links


def make_link_absolute(url, base_url):
if url.startswith("http://") or url.startswith("https://"):
return url
try:
return urljoin(base_url, url)
except ValueError:
return url


def make_links_absolute(links, base_url):
return [{"url": make_link_absolute(link["url"], base_url), "alt": link["alt"]} for link in links]


def extract_documents_from_links(links, document_type):
"""Extract documents from links; this function returns a list of dicts {"alt": ..., "url": ...}"""

if document_type == "image":
return extract_image_from_links(links)
elif document_type == "image_only":
return extract_image_only_from_links(links)
elif document_type == "audio":
return extract_audio_from_links(links)
elif document_type == "text":
return extract_text_from_links(links)
elif document_type == "video":
return extract_video_from_links(links)
elif document_type == "video_platform":
return extract_video_platform_from_links(links)
else:
raise ValueError(f"Unknown document type {document_type}")


def extract_documents_from_warc(path):
"""Extract iframe documents directly from a WARC file"""
with fsspec.open(path, mode="rb", compression="gzip") as f:
try:
for record in ArchiveIterator(f):
try:
page_url = str(record.headers["WARC-Target-URI"])
tree = HTMLTree.parse_from_bytes(record.reader.read())

for ele in tree.body.get_elements_by_tag_name("iframe"):
alt = extract_plain_text(str(ele.parent))
url = urljoin(page_url, ele.getattr("src"))
if url not in [None, "about:blank"]:

yield (str(hashlib.md5((alt + url).encode()).hexdigest()), url, alt, path, page_url)

except: # pylint: disable=bare-except
continue

except Exception as e: # pylint: disable=broad-except
logger.info(e)
logger.info("A shard failed to parse")


def extract_documents_from_wat(stream, document_type):
"""Extract document from stream"""
all_links = []
@@ -131,11 +182,30 @@ def extract_documents_from_wat(stream, document_type):
continue

links = metadata["Links"]
cc_filename = record_data["Container"]["Filename"]
page_url = envelope["WARC-Header-Metadata"]["WARC-Target-URI"]
# extract base URL to resolve relative URLs
base_url = envelope["WARC-Header-Metadata"]["WARC-Target-URI"]
if "Head" in metadata and "Base" in metadata["Head"]:
try:
base_url = urljoin(base_url, metadata["Head"]["Base"])
except ValueError:
pass

filtered_links = extract_documents_from_links(links, document_type)
filtered_links = make_links_absolute(filtered_links, base_url)
filtered_links = [
link
for link in filtered_links
if link["url"].startswith("http://") or link["url"].startswith("https://")
]
for link in filtered_links:
link["uid"] = str(hashlib.md5((link["alt"] + link["url"]).encode()).hexdigest())
link["cc_filename"] = cc_filename
link["page_url"] = page_url
all_links.extend(filtered_links)
# if len(all_links) > 100:
# return all_links
except Exception as e: # pylint: disable=broad-except
logger.info(e)
logger.info("A shard failed to parse")
@@ -146,32 +216,47 @@ def extract_documents_from_wat(stream, document_type):

def process_wat(path, document_type):
"""Process a single wat file"""

ext = path.replace(".gz", "").split(".")[-1].replace("/", "").lower()
if ext not in ["wat", "warc"]:
raise ValueError(f"Extension can only be either 'wat' or 'warc', you provided {ext}")

begin_read = timer()
with fsspec.open(path, "rb") as f:
for i in range(10):
try:
tf = BytesIO(f.read())
break
except Exception as ex: # pylint: disable=broad-except
if i == 9:
logger.info("failed 10 times, skipping ", path)
return
logger.info(ex)
logger.info(f"retrying reading {i}/10")
time.sleep(1)

for e in extract_documents_from_wat(tf, document_type):
yield (e["uid"], e["url"], e["alt"])
if ext == "warc" and document_type == "iframe":
for e in extract_documents_from_warc(path):
yield e
else:
with fsspec.open(path, mode="rb", compression="gzip") as f:
for i in range(10):
try:
tf = BytesIO(f.read())
break
except Exception as ex: # pylint: disable=broad-except
if i == 9:

logger.info(f"failed 10 times, skipping {path}")
return
logger.info(ex)
logger.info(f"retrying reading {i}/10")
time.sleep(1)

if ext == "wat" and document_type != "iframe":
for e in extract_documents_from_wat(tf, document_type):
yield (e["uid"], e["url"], e["alt"], e["cc_filename"], e["page_url"])
elif ext == "wat" and document_type == "iframe":
raise ValueError(f"Document type {document_type} is not supported for file type {ext}")
else:
raise ValueError(f"Unknown document type {document_type} and file type {ext}")
end_read = timer()
tot_read_time = end_read - begin_read
logger.info(f"Took {tot_read_time} to parse")


def get_cc_wat_links(source_cc_protocol):
def get_cc_wat_links(source_cc_protocol, ext):
"""Get cc wat links"""
if source_cc_protocol == "s3":
fs, p = fsspec.core.url_to_fs("s3://commoncrawl/crawl-data/")
links = ["s3://" + e for e in fs.glob(p + "/*/wat.paths.gz")]
links = ["s3://" + e for e in fs.glob(p + f"/*/{ext}.paths.gz")]
return links
elif source_cc_protocol == "http":
fs, p = fsspec.core.url_to_fs("https://commoncrawl.org/the-data/get-started/")
@@ -183,7 +268,8 @@ def get_cc_wat_links(source_cc_protocol):
e.split(" ")[0].replace("<li>s3://commoncrawl/", "https://data.commoncrawl.org/").replace("<wbr>", "")
for e in l
]
l = [(e + "/wat.paths.gz").replace("//wat", "/wat") for e in l]
l = [(e + f"/{ext}.paths.gz").replace(f"//{ext}", f"/{ext}") for e in l]

return l
else:
raise ValueError(f"Unknown protocol {source_cc_protocol}")
@@ -195,9 +281,9 @@ def read_wat_index_file(wat_index):
return wats


def read_wat_index_files(shard_count, wat_count, source_cc_protocol):
def read_wat_index_files(shard_count, wat_count, source_cc_protocol, ext):
"""Read all wat index files"""
cc_wat_links = get_cc_wat_links(source_cc_protocol)
cc_wat_links = get_cc_wat_links(source_cc_protocol, ext)
if shard_count is not None:
cc_wat_links = cc_wat_links[-shard_count:] # pylint: disable=invalid-unary-operand-type
all_wats = []
@@ -243,7 +329,8 @@ def extract(x):
yield from process_wat(prefix + x[0], document_type)

output = wat_rdd.mapPartitions(extract)
df = output.toDF(["uid", "url", "alt"])
df = output.toDF(["uid", "url", "alt", "cc_filename", "page_url"])
df = df.na.drop(subset=["url"]).filter(df.url != "about:blank")

deduplicate_repartition_count(df, output_path, wat_count, spark, shuffle)

@@ -307,9 +394,15 @@ def cc2dataset(
spark_builder=None,
document_type="image",
source_cc_protocol="s3",
file_type="wat",
):
"""Convert common crawl to image caption set"""

file_type = file_type.lower()

if file_type not in ["wat", "warc"]:
raise ValueError("File type can only be either 'wat' or 'warc'")

if resume is not None and multipart is None:
raise ValueError("Cannot resume without multipart")

@@ -323,7 +416,9 @@ def cc2dataset(
logger.info(f"Writing in: {output_path}")

if spark_builder is None:
spark_builder = lambda: build_spark_session(master, num_cores, mem_gb)

def spark_builder():
return build_spark_session(master, num_cores, mem_gb)

def build_spark():
spark = SparkSession.getActiveSession()
@@ -332,12 +427,12 @@ def build_spark():
return spark_builder()

if resume is None:
wat_index_files = read_wat_index_files(wat_index_count, wat_count, source_cc_protocol)
wat_index_files = read_wat_index_files(wat_index_count, wat_count, source_cc_protocol, file_type)
# write wat index files to disk in output_path with fsspec
with fsspec.open(f"{output_path}/wat_index_files.txt", "w", encoding="utf8") as f:
with fsspec.open(f"{output_path}/{file_type}_index_files.txt", "w", encoding="utf8") as f:
f.write("\n".join(wat_index_files))
else:
with fsspec.open(f"{output_path}/wat_index_files.txt", "r", encoding="utf8") as f:
with fsspec.open(f"{output_path}/{file_type}_index_files.txt", "r", encoding="utf8") as f:
wat_index_files = f.read().splitlines()

if multipart is None:
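Note on the relative-link handling added above: make_link_absolute and make_links_absolute lean on urllib.parse.urljoin, optionally combined with the page's Base element taken from the WAT Head metadata, and links that still lack an http(s) scheme after resolution are dropped by the new filter in extract_documents_from_wat. A minimal standalone sketch of the expected behaviour, with made-up URLs for illustration:

from urllib.parse import urljoin

# Hypothetical page URL and link values, for illustration only.
base_url = "https://example.com/blog/post.html"

print(urljoin(base_url, "images/cat.jpg"))                 # https://example.com/blog/images/cat.jpg
print(urljoin(base_url, "/media/clip.mp4"))                # https://example.com/media/clip.mp4
print(urljoin(base_url, "https://cdn.example.org/a.mp4"))  # already-absolute URLs pass through unchanged
print(urljoin(base_url, "mailto:someone@example.com"))     # no http(s) scheme, so the new filter drops it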
3 changes: 1 addition & 2 deletions cc2dataset/spark_session_builder.py
@@ -46,7 +46,7 @@ def aws_ec2_s3_spark_session(master, num_cores=128, mem_gb=256):
"spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.1,org.apache.spark:spark-hadoop-cloud_2.13:3.3.1"
)
# change to the appropriate auth method, see https://hadoop.apache.org/docs/stable/hadoop-aws/tools/hadoop-aws/index.html
.config("spark.hadoop.fs.s3a.aws.credentials.provider", "com.amazonaws.auth.InstanceProfileCredentialsProvider")
# .config("spark.hadoop.fs.s3a.aws.credentials.provider", "com.amazonaws.auth.InstanceProfileCredentialsProvider")
# ton of options to try and make s3a run faster
.config("spark.hadoop.fs.s3a.threads.max", "512")
.config("spark.hadoop.fs.s3a.connection.maximum", "2048")
@@ -65,7 +65,6 @@
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.hadoop.fs.s3a.experimental.input.fadvise", "random")
.config("spark.hadoop.fs.s3a.block.size", "2M")
.config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false")
.config("spark.hadoop.fs.s3a.fast.buffer.size", "100M")
.config("spark.hadoop.fs.s3a.fast.upload.buffer", "array")
.config("spark.hadoop.fs.s3a.bucket.all.committer.magic.enabled", "true")
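With the hard-coded InstanceProfileCredentialsProvider now commented out, S3 access falls back to whatever credentials hadoop-aws picks up from the environment, so anyone running outside EC2 may want to set a provider explicitly. A sketch of one way to do that; the provider class below is an example, not something this PR prescribes:

from pyspark.sql import SparkSession

# Sketch only: pick the credentials provider that matches your deployment,
# per the hadoop-aws documentation linked in the comment above.
spark = (
    SparkSession.builder.master("local[16]")
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    )
    .getOrCreate()
)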
2 changes: 1 addition & 1 deletion examples/single_warc_example.py
@@ -10,7 +10,7 @@
else:
url = "https://data.commoncrawl.org/" + wat

results = process_wat(url, "image")
results = process_wat(url, "video_platform")
df = pd.DataFrame(results, columns=["uid", "url", "alt"])
df.to_parquet(os.getcwd() + "/output.parquet")
print(df)
18 changes: 18 additions & 0 deletions examples/single_warc_iframe_example.py
@@ -0,0 +1,18 @@
from cc2dataset import process_wat
import os
import pandas as pd

if __name__ == "__main__":
from_s3 = False
wat = (
"crawl-data/CC-MAIN-2023-06/segments/1674764494974.98/warc/CC-MAIN-20230127065356-20230127095356-00752.warc.gz"
)
if from_s3:
url = "s3://commoncrawl/" + wat
else:
url = "https://data.commoncrawl.org/" + wat

results = process_wat(url, "iframe")
df = pd.DataFrame(results, columns=["uid", "url", "alt", "cc_filename", "page_url"])
df.to_parquet(os.getcwd() + "/output.parquet")
print(df)
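Assuming the package exposes its main entry point the same way it exposes process_wat, the full pipeline can presumably be pointed at WARC files through the new file_type option; a sketch of such a call, where the bucket path is a placeholder and only document_type, file_type and source_cc_protocol are options actually touched by this PR:

from cc2dataset import cc2dataset

# Placeholder values; output_path appears in the diff above, the concrete
# path here is an assumption.
cc2dataset(
    output_path="s3://some-bucket/iframe-links",
    document_type="iframe",
    file_type="warc",
    source_cc_protocol="s3",
)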