Skip to content

Commit

Permalink
Support nested input format for webdataset (#84)
Browse files Browse the repository at this point in the history
* Support nested input format for webdataset

* update

* comment

* Fix black

* make comment

* fix lint

* remove trailing

* black

---------

Co-authored-by: Maciej Kilian <iejmac@ip-172-64-54-169.us-west-2.compute.internal>
Co-authored-by: Maciej Kilian <iejmac@ip-26-0-147-155.us-west-2.compute.internal>
  • Loading branch information
3 people authored Oct 31, 2023
1 parent 7de9c3a commit b5df2c7
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 83 deletions.
185 changes: 109 additions & 76 deletions clip_video_encode/clip_video_encode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""encode video with CLIP"""
import re
import sys
import time

Expand Down Expand Up @@ -27,6 +28,17 @@ def _convert_image_to_rgb(image):
return image.convert("RGB")


def extract_braceexpand_values(be_template, path):
    """Extract the values that fill the braceexpand ranges of a template.

    Given a braceexpand-style template such as ``"shard-{000..100}.tar"`` and a
    concrete path such as ``"shard-042.tar"``, return the digit strings that
    occupy each ``{...}`` position, in order (here ``["042"]``).

    Args:
        be_template: Template string containing one or more ``{...}`` ranges.
        path: Concrete path produced by expanding the template.

    Returns:
        List of the digit substrings of ``path`` matched at each brace position.

    Raises:
        ValueError: If ``path`` does not match the template.
    """
    # Split the template on the brace ranges, regex-escape the literal pieces
    # (the original code only escaped ".", which broke on other metacharacters
    # such as "+" or "(" in the path), and rejoin with a digit-capturing group.
    literal_parts = re.split(r"\{.*?\}", be_template)
    reg_template = r"(\d+)".join(re.escape(part) for part in literal_parts)

    match = re.match(reg_template, path)
    if match is None:
        # Fail loudly with context instead of an opaque AttributeError on None.
        raise ValueError(f"path {path!r} does not match template {be_template!r}")

    return list(match.groups())


def clip_video_encode(
src,
dest="",
Expand All @@ -46,6 +58,7 @@ def clip_video_encode(
frame_tokenization_strategy="none",
generated_caption_key="generated_caption", # this will put it in json, make this 'caption' if you want it in txt
pass_through_keys="mp4,txt,json",
vid_ext="mp4",
caption_similarity=False,
img_size=224,
):
Expand Down Expand Up @@ -120,10 +133,19 @@ def clip_video_encode(
fs.mkdir(output_path)
done_shards = set()
else:
done_shards = set(int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.tar"))
done_shards = set(x.split("/")[-1].split("_")[0] for x in fs.glob(output_path + "/*.tar"))

print(f"Removing {len(done_shards)} done_shards from processing queue...")
s_ids = [s.split("/")[-1][: -len(".tar")] for s in shards]

# TODO: finish this
# def get_sids(be_template):
# shards = list(braceexpand.braceexpand(be_template))
# values = extract_braceexpand_values(be_template, path)
# max_values = extract_braceexpand_values(be_template, list(braceexpand.braceexpand(be_template))[-1])
# for i in range(len(values)):
# values[i] = values[i].zfill(len(max_values[i]))
# write_shard_id = "".join(values)
# return write_shard_id
shards = [s for s_id, s in zip(s_ids, shards) if int(s_id) not in done_shards]

starting_shard_id = 0
Expand Down Expand Up @@ -200,62 +222,92 @@ def clip_video_encode(
encode_chunk(frames, ind_dict, writer, fm, meta, ids, use_dst_name, device, input_format=input_format)
else: # WebDataset shard logic
for shard in shards:
# try:
times = {}
t = time.time()
with tempfile.TemporaryDirectory(prefix=f"worker_{global_rank}_") as tempdir:
os.chmod(tempdir, 0o777) # This lets subprocesses from v2np read files in the tempdir
folder = "/".join(shard.split("/")[0:-1])
fs, output_path = fsspec.core.url_to_fs(folder)

shard_id = shard.split("/")[-1]
tar_bytes = io.BytesIO(fs.open(f"{output_path}/{shard_id}").read())
with tarfile.open(fileobj=tar_bytes) as tar:
tar.extractall(tempdir)
writer.create_shard(shard_id=int(shard_id.split(".tar")[0]))
times["download_and_extract"] = times.get("download_and_extract", 0) + time.time() - t
try:
values = extract_braceexpand_values(src, shard)
max_values = extract_braceexpand_values(src, list(braceexpand.braceexpand(src))[-1])
for i, val in enumerate(values):
values[i] = val.zfill(len(max_values[i]))
write_shard_id = "".join(values)

# TODO: find better way of doing this earlier
if write_shard_id in done_shards:
continue

times = {}
t = time.time()
with tempfile.TemporaryDirectory(prefix=f"worker_{global_rank}_") as tempdir:
os.chmod(tempdir, 0o777) # This lets subprocesses from v2np read files in the tempdir
folder = "/".join(shard.split("/")[0:-1])
fs, output_path = fsspec.core.url_to_fs(folder)

read_shard_id = shard.split("/")[-1].split(".tar")[0]

tar_bytes = io.BytesIO(fs.open(f"{output_path}/{read_shard_id}.tar").read())
with tarfile.open(fileobj=tar_bytes) as tar:
tar.extractall(tempdir)
writer.create_shard(shard_id=write_shard_id)
times["download_and_extract"] = times.get("download_and_extract", 0) + time.time() - t
t = time.time()

vids, ids, meta = read_shard(tempdir, pass_through_keys=pass_through_keys)

meta_refs = list(range(len(vids)))
fr = FrameReader(
vids,
meta_refs,
take_every_nth=take_every_nth,
target_fps=target_fps,
resize_size=img_size,
workers=frame_workers,
memory_size=frame_memory_size,
)
fr.start_reading()

frames, ind_dict = [], {}
block_size = 0
i = 0
n_frames = 0
for vid_frames, info in fr:
i += 1

if captioning_strategy == "center":
vid_frames = vid_frames[len(vid_frames) // 2 : len(vid_frames) // 2 + 1]

n_frames += len(vid_frames)
frames.append(vid_frames)
ind_dict[info["reference"]] = (
block_size,
block_size + vid_frames.shape[0],
info["dst_name"],
vids, ids, meta = read_shard(tempdir, vid_ext, pass_through_keys=pass_through_keys)

meta_refs = list(range(len(vids)))
fr = FrameReader(
vids,
meta_refs,
take_every_nth=take_every_nth,
target_fps=target_fps,
resize_size=img_size,
workers=frame_workers,
memory_size=frame_memory_size,
)
block_size += vid_frames.shape[0]
times["read_frames"] = times.get("read_frames", 0) + time.time() - t
t = time.time()
fr.start_reading()

frames, ind_dict = [], {}
block_size = 0
i = 0
n_frames = 0
for vid_frames, info in fr:
i += 1

if captioning_strategy == "center":
vid_frames = vid_frames[len(vid_frames) // 2 : len(vid_frames) // 2 + 1]

n_frames += len(vid_frames)
frames.append(vid_frames)
ind_dict[info["reference"]] = (
block_size,
block_size + vid_frames.shape[0],
info["dst_name"],
)
block_size += vid_frames.shape[0]
times["read_frames"] = times.get("read_frames", 0) + time.time() - t
t = time.time()

if i % CHUNK_SIZE == 0:
if i % CHUNK_SIZE == 0:
encode_chunk(
frames,
ind_dict,
writer, # TODO: turn all args below this into kwarg dict and just unpack
fm,
meta,
ids,
use_dst_name,
device,
input_format=input_format,
captioning_strategy=captioning_strategy,
frame_tokenization_strategy=frame_tokenization_strategy,
generated_caption_key=generated_caption_key,
)
times["encode"] = times.get("encode", 0) + time.time() - t
t = time.time()
frames, ind_dict, block_size = [], {}, 0
t = time.time()
if len(frames) > 0: # TODO: make this cleaner
encode_chunk(
frames,
ind_dict,
writer, # TODO: turn all args below this into kwarg dict and just unpack
writer,
fm,
meta,
ids,
Expand All @@ -266,31 +318,12 @@ def clip_video_encode(
frame_tokenization_strategy=frame_tokenization_strategy,
generated_caption_key=generated_caption_key,
)
times["encode"] = times.get("encode", 0) + time.time() - t
t = time.time()
frames, ind_dict, block_size = [], {}, 0
t = time.time()
if len(frames) > 0: # TODO: make this cleaner
encode_chunk(
frames,
ind_dict,
writer,
fm,
meta,
ids,
use_dst_name,
device,
input_format=input_format,
captioning_strategy=captioning_strategy,
frame_tokenization_strategy=frame_tokenization_strategy,
generated_caption_key=generated_caption_key,
)
times["encode"] = times.get("encode", 0) + time.time() - t
t = time.time()
frame_adjusted = {k: n_frames / v for k, v in times.items()}
print(f"Frames/s: {frame_adjusted}")
# except Exception as e: # pylint: disable=(broad-except)
# print(f"Shard {shard} failed: {str(e)}")
times["encode"] = times.get("encode", 0) + time.time() - t
t = time.time()
frame_adjusted = {k: n_frames / v for k, v in times.items()}
print(f"Frames/s: {frame_adjusted}")
except Exception as e: # pylint: disable=(broad-except)
print(f"Shard {shard} failed: {str(e)}")


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions clip_video_encode/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_data(self):
return vids, ids, meta


def read_shard(tempdir, pass_through_keys=None):
def read_shard(tempdir, vid_ext="mp4", pass_through_keys=None):
"""
Extracts shard a tempdir and returns references to files inside
Expand All @@ -86,15 +86,15 @@ def read_shard(tempdir, pass_through_keys=None):
pass_through_keys = []

vids = sorted(
[f.split("/")[-1] for f in glob.glob(os.path.join(tempdir, "*.mp4"))]
[f.split("/")[-1] for f in glob.glob(os.path.join(tempdir, f"*.{vid_ext}"))]
) # TODO: parameterize the video extension

read_funcs = {
"json": lambda path: json.load(open(path, "rb")), # pylint: disable=consider-using-with
"txt": lambda path: open(path, "r", encoding="UTF-8").read(), # pylint: disable=consider-using-with
}

keys, meta = [x.split(".mp4")[0] for x in vids], []
keys, meta = [x.split(f".{vid_ext}")[0] for x in vids], []
for key in keys:
metadata = {}

Expand All @@ -107,7 +107,7 @@ def read_shard(tempdir, pass_through_keys=None):
if ext in read_funcs:
read_data = read_funcs[ext](file_path)
else:
read_data = open(path, "rb").read() # pylint: disable=consider-using-with
read_data = open(file_path, "rb").read() # pylint: disable=consider-using-with
metadata[ext] = read_data

meta.append(metadata)
Expand Down
9 changes: 6 additions & 3 deletions clip_video_encode/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,12 @@ def create_shard(self, shard_id=None):
self.close()
if shard_id is not None:
self.shard_id = shard_id
shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string
shard_id=self.shard_id, oom_shard_count=self.oom_shard_count
)

shard_name = shard_id
if not isinstance(shard_id, str):
shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string
shard_id=self.shard_id, oom_shard_count=self.oom_shard_count
)
shard_name += "_" + self.shard_suffix

fs, output_path = fsspec.core.url_to_fs(self.output_folder)
Expand Down

0 comments on commit b5df2c7

Please sign in to comment.