Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Put values, lengths and offsets of NJTs together in storage #1023

Open
wants to merge 10 commits into
base: gh/vmoens/24/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import copyreg
import queue
from multiprocessing.reduction import ForkingPickler

import torch
Expand All @@ -21,7 +22,17 @@


def _rebuild_tensordict_files(flat_key_values, metadata_dict, is_shared: bool = False):
_nt_values_and_keys = queue.Queue()
_nt_lengths = queue.Queue()
_nt_offsets = queue.Queue()

def from_metadata(metadata=metadata_dict, prefix=None):
metadata = dict(metadata)

_ = metadata.pop("njt_values_start", None)
_ = metadata.pop("njt_lengths_start", None)
_ = metadata.pop("njt_offsets_start", None)

non_tensor = metadata.pop("non_tensors")
leaves = metadata.pop("leaves")
cls = metadata.pop("cls")
Expand All @@ -36,22 +47,21 @@ def from_metadata(metadata=metadata_dict, prefix=None):
total_key = (key,) if prefix is None else prefix + (key,)
if total_key[-1].startswith("<NJT>"):
nested_values = flat_key_values[total_key]
nested_lengths = None
total_key = total_key[:-1] + total_key[-1].replace("<NJT>", "")
_nt_values_and_keys.put((nested_values, total_key))
continue
if total_key[-1].startswith("<NJT_LENGTHS>"):
nested_lengths = flat_key_values[total_key]
_nt_lengths.put(nested_lengths)
continue
elif total_key[-1].startswith("<NJT_OFFSETS"):
offsets = flat_key_values[total_key]
key = key.replace("<NJT_OFFSETS>", "")
value = torch.nested.nested_tensor_from_jagged(
nested_values, offsets=offsets, lengths=nested_lengths
)
del nested_values
del nested_lengths
_nt_offsets.put(offsets)
continue
else:
value = flat_key_values[total_key]
d[key] = value

for k, v in metadata.items():
# Each remaining key is a tuple pointing to a sub-tensordict
d[k] = from_metadata(
Expand All @@ -64,7 +74,18 @@ def from_metadata(metadata=metadata_dict, prefix=None):
# result._is_shared = is_shared
return result

return from_metadata()
result = from_metadata()
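# Nested jagged tensors were queued during the traversal above; rebuild each
# one from its queued values, offsets and lengths and set it at its key.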
while not _nt_values_and_keys.empty():
vals, key = _nt_values_and_keys.get()
lengths = _nt_lengths.get()
offsets = _nt_offsets.get()
value = torch.nested.nested_tensor_from_jagged(
vals, offsets=offsets, lengths=lengths
)
result._set_tuple(key, value, inplace=False, validated=True)

return result


def _rebuild_tensordict_files_shared(flat_key_values, metadata_dict):
Expand All @@ -75,9 +96,18 @@ def _rebuild_tensordict_files_consolidated(
metadata,
storage,
):
_nt_values_and_keys = queue.Queue()
_nt_lengths = queue.Queue()
_nt_offsets = queue.Queue()

def from_metadata(metadata=metadata, prefix=None):
consolidated = {"storage": storage, "metadata": metadata}
metadata = dict(metadata)

_ = metadata.pop("njt_values_start", None)
_ = metadata.pop("njt_lengths_start", None)
_ = metadata.pop("njt_offsets_start", None)

non_tensor = metadata.pop("non_tensors")
leaves = metadata.pop("leaves")
cls = metadata.pop("cls")
Expand All @@ -101,33 +131,45 @@ def from_metadata(metadata=metadata, prefix=None):
if key.startswith("<NJT>"):
raise RuntimeError
elif key.startswith("<NJT_VALUES>"):
nested_values = value
nested_lengths = None
key = key.replace("<NJT_VALUES>", "")
if prefix:
total_key = prefix + (key,)
else:
total_key = (key,)
_nt_values_and_keys.put((value, total_key))
continue
elif key.startswith("<NJT_LENGTHS>"):
nested_lengths = value
_nt_lengths.put(value)
continue
elif key.startswith("<NJT_OFFSETS>"):
from torch.nested._internal.nested_tensor import NestedTensor

offsets = value
value = NestedTensor(
nested_values, offsets=offsets, lengths=nested_lengths
)
key = key.replace("<NJT_OFFSETS>", "")
_nt_offsets.put(value)
if _nt_offsets.qsize() > _nt_lengths.qsize():
_nt_lengths.put(None)
continue
d[key] = value
for k, v in metadata.items():
for key, val in metadata.items():
# Each remaining key is a tuple pointing to a sub-tensordict
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
d[key] = from_metadata(
val, prefix=prefix + (key,) if prefix is not None else (key,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
result._consolidated = consolidated
return result

return from_metadata()
result = from_metadata()
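# Nested jagged tensors were queued during the traversal above; rebuild each
# one from its queued values, offsets and lengths and set it at its key.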
while not _nt_values_and_keys.empty():
vals, key = _nt_values_and_keys.get()
lengths = _nt_lengths.get()
offsets = _nt_offsets.get()
value = torch.nested.nested_tensor_from_jagged(
vals, offsets=offsets, lengths=lengths
)
result._set_tuple(key, value, inplace=False, validated=True)

return result


def _make_td(cls, state):
Expand Down
Loading
Loading