Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

read_shards: generalize and improve #83

Merged
merged 9 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion clip_video_encode/clip_video_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def clip_video_encode(
writer.create_shard(shard_id=int(shard_id.split(".tar")[0]))
times["download_and_extract"] = times.get("download_and_extract", 0) + time.time() - t
t = time.time()

vids, ids, meta = read_shard(tempdir, pass_through_keys=pass_through_keys)

meta_refs = list(range(len(vids)))
fr = FrameReader(
vids,
Expand Down Expand Up @@ -253,7 +255,7 @@ def clip_video_encode(
encode_chunk(
frames,
ind_dict,
writer,
writer, # TODO: turn all args below this into kwarg dict and just unpack
fm,
meta,
ids,
Expand Down
22 changes: 14 additions & 8 deletions clip_video_encode/handle_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def encode_chunk(
vid_id = dst_name[:-4] if use_dst_name else ids[ref]
if input_format == "webdataset":
vid_meta = meta[ref]
vid_meta["json"] = vid_meta["json"] if "json" in vid_meta else {}
else:
vid_meta = {}
vid_meta = {"json": {}}
for k in meta:
vid_meta[k] = meta[k][ref].as_py()
vid_meta["json"][k] = meta[k][ref].as_py()

# NOTE: Warning this might overwrite previous caption
# NOTE: for now assumes there is only one caption
vid_meta[generated_caption_key] = captions[i0:it][0]
vid_meta["json"][generated_caption_key] = captions[i0:it][0]

# TODO: we should be able to do both at once with a CoCa model
writer.write(None, vid_id, vid_meta)
Expand All @@ -63,9 +64,11 @@ def encode_chunk(
if input_format == "webdataset":
vid_meta = meta[ref]
else:
vid_meta = {}
vid_meta = {"json": {}}
for k in meta:
vid_meta[k] = meta[k][ref].as_py()
vid_meta["json"][k] = meta[k][ref].as_py()
if "caption" in vid_meta["json"]:
vid_meta["txt"] = vid_meta["json"]["caption"]

video_tokens = tokens[i0:it]
writer.write(video_tokens, vid_id, vid_meta)
Expand All @@ -90,9 +93,11 @@ def encode_chunk(
if input_format == "webdataset":
vid_meta = meta[ref]
else:
vid_meta = {}
vid_meta = {"json": {}}
for k in meta:
vid_meta[k] = meta[k][ref].as_py()
vid_meta["json"][k] = meta[k][ref].as_py()
if "caption" in vid_meta["json"]:
vid_meta["txt"] = vid_meta["json"]["caption"]

frame_embeddings = embeddings[i0:it]
if caption_embs is not None:
Expand All @@ -102,6 +107,7 @@ def encode_chunk(

sim = (fe @ ce.T).tolist()

vid_meta["clip_frame_similarity"] = sim
vid_meta["json"] = vid_meta["json"] if "json" in vid_meta else {}
vid_meta["json"]["clip_frame_similarity"] = sim

writer.write(frame_embeddings, vid_id, vid_meta)
47 changes: 24 additions & 23 deletions clip_video_encode/reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""handles input parsing."""
import os
import json
import glob

Expand Down Expand Up @@ -71,45 +72,45 @@ def get_data(self):
return vids, ids, meta


# TODO: hard refactor
def read_shard(tempdir, pass_through_keys=None):
"""
Extract video filepaths, video ids, and metadata from the contents of an opened WebDataset shard
    Extracts a shard into a tempdir and returns references to the files inside

Input:
tempdir:
path to directory containing contents of an opened WebDataset shard with input data
pass_through_keys:
extensions we would like to keep from the source shard in the output shard
"""
if pass_through_keys is None:
pass_through_keys = []

vids = sorted(
[f.split("/")[-1] for f in glob.glob(tempdir + "/" + "*.mp4")]
[f.split("/")[-1] for f in glob.glob(os.path.join(tempdir, "*.mp4"))]
) # TODO: parameterize the video extension

has_txt = len(glob.glob(tempdir + "/" + "*.txt")) > 0
has_json = len(glob.glob(tempdir + "/" + "*.json")) > 0
read_funcs = {
"json": lambda path: json.load(open(path, "rb")), # pylint: disable=consider-using-with
"txt": lambda path: open(path, "r", encoding="UTF-8").read(), # pylint: disable=consider-using-with
}

keys = [x.split(".mp4")[0] for x in vids]
meta = []
keys, meta = [x.split(".mp4")[0] for x in vids], []
for key in keys:
if has_json and "json" in pass_through_keys:
with open(tempdir + "/" + key + ".json", "rb") as f:
metadata = json.load(f)
else:
metadata = {}

if has_txt and "txt" in pass_through_keys:
with open(tempdir + "/" + key + ".txt", "r", encoding="UTF-8") as f:
txt = f.read()
metadata["caption"] = txt

if "mp4" in pass_through_keys:
with open(tempdir + "/" + key + ".mp4", "rb") as f:
mp4_video = f.read()
metadata["mp4_video"] = mp4_video
metadata = {}

    # handles double extensions for unusual metadata types, e.g. ".optical-flow.npy" vs. ".clip_b.npy"
exts = [".".join(f.split(".")[1:]) for f in glob.glob(os.path.join(tempdir, f"{key}.*"))]
desired_exts = list(set(pass_through_keys).intersection(set(exts)))

for ext in desired_exts:
file_path = os.path.join(tempdir, f"{key}.{ext}")
if ext in read_funcs:
read_data = read_funcs[ext](file_path)
else:
            read_data = open(file_path, "rb").read()  # pylint: disable=consider-using-with  # fixed: was `path`, an undefined name (NameError); loop variable is `file_path`
metadata[ext] = read_data

meta.append(metadata)

vids = [tempdir + "/" + v for v in vids]
vids = [os.path.join(tempdir, v) for v in vids]
return vids, keys, meta
42 changes: 16 additions & 26 deletions clip_video_encode/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from io import BytesIO


write_fmt = {
"mp4": lambda data: data, # pylint: disable=unnecessary-lambda
"txt": lambda data: str(data), # pylint: disable=unnecessary-lambda
"json": lambda data: json.dumps(data, indent=4),
}


class FileWriter:
"""Writes output as files."""

Expand All @@ -19,30 +26,19 @@ def __init__(self, output_folder):

def write(self, arr, key, metadata=None):
"""write sample to file."""
key = str(key)
key, metadata = str(key), {} if metadata is None else metadata

save_pth = os.path.join(self.output_folder, key + ".npy")
with self.fs.open(save_pth, "wb") as f:
nbp = BytesIO()
np.save(nbp, arr)
f.write(nbp.getbuffer())

if metadata is not None:
if "caption" in metadata:
caption = str(metadata.pop("caption"))
caption_filename = os.path.join(self.output_folder, key + ".txt")
with self.fs.open(caption_filename, "w") as f:
f.write(caption)
if len(metadata) > 0:
if "mp4_video" in metadata:
vid_bytes = metadata.pop("mp4_video")
vid_filename = os.path.join(self.output_folder, key + ".mp4")
with self.fs.open(vid_filename, "w") as f:
f.write(vid_bytes)
j = json.dumps(metadata, indent=4)
meta_filename = os.path.join(self.output_folder, key + ".json")
with self.fs.open(meta_filename, "w") as f:
f.write(j)
for ext in metadata:
md_filename = os.path.join(self.output_folder, f"{key}.{ext}")
write_data = write_fmt[ext](metadata[ext]) if ext in write_fmt else metadata[ext]
                # mp4 pass-through data is raw bytes (write_fmt["mp4"] is identity), so pick the
                # file mode accordingly — writing bytes to a text-mode handle raises TypeError.
                mode = "wb" if isinstance(write_data, bytes) else "w"
                with self.fs.open(md_filename, mode) as f:
                    f.write(write_data)

def close(self):
pass
Expand Down Expand Up @@ -82,7 +78,7 @@ def create_shard(self, shard_id=None):

def write(self, arr, key, metadata=None):
"""write sample to current shard."""
key = str(key)
key, metadata = str(key), {} if metadata is None else metadata
if self.count >= self.maxcount:
self.shard_id += 1
self.count = 0
Expand All @@ -92,14 +88,8 @@ def write(self, arr, key, metadata=None):
if arr is not None:
sample[self.encode_format] = arr

if metadata is not None:
if "caption" in metadata:
sample["txt"] = str(metadata.pop("caption"))
if len(metadata) > 0:
if "mp4_video" in metadata:
vid_bytes = metadata.pop("mp4_video")
sample["mp4"] = vid_bytes
sample["json"] = json.dumps(metadata, indent=4)
for ext in metadata:
sample[ext] = write_fmt[ext](metadata[ext]) if ext in write_fmt else metadata[ext]

self.tarwriter.write(sample)
self.count += 1
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_writer(writer_type):
vid_embeds = [np.ones((N_FRAMES, lat_dim), dtype=float) * i for i in range(N_VIDS)]

for i, emb in enumerate(vid_embeds):
fake_metadata = {"caption": str(i), "x": i}
fake_metadata = {"json": {"caption": str(i), "x": i}, "txt": str(i)}
writer.write(emb, str(i), fake_metadata)
writer.close()

Expand Down
Loading