From 29fbfeef624ed299a4642475973c36ec920b87fa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 14 Oct 2024 18:41:48 +0100 Subject: [PATCH 01/37] Update [ghstack-poisoned] --- tensordict/base.py | 73 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 1795504b6..4e9ca37e9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10456,17 +10456,88 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): if pin_memory: storage = storage.pin_memory() storage_cast = storage.to(device, non_blocking=True) + if is_dynamo_compiling(): + return self._to_reconstruct_compiled( + storage, storage_cast, device, num_threads, non_blocking + ) + return self._to_reconstruct( + storage, storage_cast, device, num_threads, non_blocking + ) + + def _to_reconstruct(self, storage, storage_cast, device, num_threads, non_blocking): untyped_storage = storage_cast.untyped_storage() def set_(x): + if x.is_nested: + if x.layout != torch.jagged: + raise RuntimeError( + "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " + "Please raise an issue on GitHub." + ) + values = x._values + lengths = x._lengths + offsets = x._offsets + return torch.nested.nested_tensor_from_jagged( + set_(values), + offsets=set_(offsets), + lengths=set_(lengths) if lengths is not None else None, + ) storage_offset = x.storage_offset() stride = x.stride() - return torch.empty_like(x, device=device).set_( + return x.new_empty((0,), device=device).set_( untyped_storage, size=x.shape, stride=stride, storage_offset=storage_offset, ) + # return torch.empty_like(x, device=device).set_( + # untyped_storage, + # size=x.shape, + # stride=stride, + # storage_offset=storage_offset, + # ) + + result = self._fast_apply( + set_, device=torch.device(device), num_threads=num_threads + ) + result._consolidated = {"storage": storage_cast} + if "metadata" in self._consolidated: + result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"]) + if non_blocking in (False, None): + if device.type == "cuda" and non_blocking is False: + # sending to CUDA force sync + cuda_device = device + elif storage.device.type == "cuda": + # sending from cuda: need sync unless intentionally not asked for + cuda_device = storage.device.type + else: + cuda_device = None + if cuda_device is not None: + torch.cuda.current_stream(cuda_device).synchronize() + + return result + + def _to_reconstruct_compiled(self, storage, storage_cast, device, num_threads, non_blocking): + def set_(x): + if x.is_nested: + if x.layout != torch.jagged: + raise RuntimeError( + "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " + "Please raise an issue on GitHub." 
+ ) + values = x._values + lengths = x._lengths + offsets = x._offsets + return torch._nested_view_from_jagged( + set_(values), + set_(offsets), + x, + lengths=set_(lengths) if lengths is not None else None, + ) + storage_offset = x.storage_offset() + stride = x.stride() + index_slice = slice(storage_offset, storage_offset + x.numel(), stride[0]) + return storage_cast.view(x.dtype)[index_slice].view(x.type) result = self._fast_apply( set_, device=torch.device(device), num_threads=num_threads From ae18ecd55a4ede63d2559c34cf947890516e4392 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Oct 2024 08:27:03 +0100 Subject: [PATCH 02/37] Update tensordict/base.py Co-authored-by: Shagun Sodhani <1321193+shagunsodhani@users.noreply.github.com> --- tensordict/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 4e9ca37e9..4463e83e9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10472,7 +10472,7 @@ def set_(x): if x.layout != torch.jagged: raise RuntimeError( "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " - "Please raise an issue on GitHub." + "Please raise an issue on GitHub: https://github.com/pytorch/tensordict/issues" ) values = x._values lengths = x._lengths From 775771e4ce1fc716d62cf7e6f2a2850c9d6242b2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Oct 2024 13:30:56 +0100 Subject: [PATCH 03/37] Update [ghstack-poisoned] --- tensordict/base.py | 72 +++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 00d36c0bc..fd5d1a7f4 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10394,11 +10394,12 @@ def to(self, *args, **kwargs) -> T: return result if self.is_consolidated() and dtype is None: - return self._to_consolidated_compile( + return self._to_consolidated( device=device, pin_memory=non_blocking_pin, num_threads=num_threads, non_blocking=non_blocking, + compilable=is_dynamo_compiling(), ) if non_blocking is None: @@ -10456,14 +10457,42 @@ def to_pinmem(tensor, _to=to): self._sync_all() return result - def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): + def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking, compilable): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 + storage = self._consolidated["storage"] - if pin_memory: - storage = storage.pin_memory() - storage_cast = storage.to(device, non_blocking=True) + + @torch.compiler.disable() + def to(storage): + if pin_memory: + storage = storage.pin_memory() + storage_cast = storage.to(device, non_blocking=True) + return storage_cast + storage_cast = to(storage) + + if compilable: + result = self._to_consolidated_compile(device=device, num_threads=num_threads, storage_cast=storage_cast) + else: + result = self._to_consolidated_eager(device=device, num_threads=num_threads, storage_cast=storage_cast) + + if non_blocking in (False, None): + if device.type == "cuda" and non_blocking is False: + # sending to CUDA force sync + cuda_device = device + elif storage.device.type == "cuda": + # sending from cuda: need sync unless intentionally not asked for + cuda_device = storage.device.type + else: + cuda_device = None + if cuda_device is not None: + torch.cuda.current_stream(cuda_device).synchronize() + + return result + + def _to_consolidated_eager(self, *, device, num_threads, storage_cast): + untyped_storage = 
storage_cast.untyped_storage() def set_(x): @@ -10532,21 +10561,10 @@ def copy_dict(d): } result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) - if non_blocking in (False, None): - if device.type == "cuda" and non_blocking is False: - # sending to CUDA force sync - cuda_device = device - elif storage.device.type == "cuda": - # sending from cuda: need sync unless intentionally not asked for - cuda_device = storage.device.type - else: - cuda_device = None - if cuda_device is not None: - torch.cuda.current_stream(cuda_device).synchronize() - return result - def _to_consolidated_compile(self, *, device, pin_memory, num_threads, non_blocking): + @torch.compile(dynamic=True) + def _to_consolidated_compile(self, *, device, num_threads, storage_cast): def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()): root = False @@ -10579,10 +10597,6 @@ def split_storage(consolidated): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 - storage = self._consolidated["storage"] - if pin_memory: - storage = storage.pin_memory() - storage_cast = storage.to(device, non_blocking=True) _consolidated = {"storage": storage_cast} if "metadata" in self._consolidated: @@ -10649,21 +10663,7 @@ def set_(name, x): set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True, ) result._consolidated = _consolidated - - if non_blocking in (False, None): - if device.type == "cuda" and non_blocking is False: - # sending to CUDA force sync - cuda_device = device - elif storage.device.type == "cuda": - # sending from cuda: need sync unless intentionally not asked for - cuda_device = storage.device.type - else: - cuda_device = None - if cuda_device is not None: - torch.cuda.current_stream(cuda_device).synchronize() - return result - def _sync_all(self): if _has_cuda: # TODO: dynamo doesn't like torch.cuda.is_initialized From a0a49f53787a06a6f39ae5b70700e26e449fd862 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Oct 2024 13:42:26 +0100 Subject: [PATCH 04/37] Update [ghstack-poisoned] --- tensordict/base.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index fd5d1a7f4..c66ff0d04 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10563,10 +10563,9 @@ def copy_dict(d): result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) return result - @torch.compile(dynamic=True) def _to_consolidated_compile(self, *, device, num_threads, storage_cast): - def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()): + def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): root = False if lengths is None: lengths = [] @@ -10579,7 +10578,7 @@ def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()): keys.append(prefix + (k,)) for k, d in metadata.items(): if "leaves" in d: - get_l(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)) + get_tensors_length(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)) if root: # l = torch.empty(len(lengths), dtype=torch.long) # l[torch.as_tensor(pos)] = torch.as_tensor(lengths) @@ -10591,7 +10590,7 @@ def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()): return out0, out1 def split_storage(consolidated): - keys, splits = get_l(consolidated["metadata"]) + keys, splits = get_tensors_length(consolidated["metadata"]) return dict(zip(keys, consolidated["storage"].split(splits))) if num_threads is None: @@ -10632,9 +10631,11 @@ def 
set_(name, x): values = x._values lengths = x._lengths offsets = x._offsets - kwargs["offsets"] = slice_map[(*name[:-1], ""+name[-1],)].view(offsets.dtype).view(offsets.shape) + storage_offsets = slice_map[(*name[:-1], ""+name[-1],)] + kwargs["offsets"] = storage_offsets.view(offsets.dtype).view(offsets.shape) if lengths is not None: - kwargs["lengths"] = slice_map[(*name[:-1], ""+name[-1],)].view(lengths.dtype).view(lengths.shape) + storage_lengths = slice_map[(*name[:-1], ""+name[-1],)] + kwargs["lengths"] = storage_lengths.view(lengths.dtype).view(lengths.shape) ragged_source = lengths else: ragged_source = offsets @@ -10653,8 +10654,9 @@ def set_(name, x): ragged_source ] + storage_values = slice_map[(*name[:-1], ""+name[-1],)] return NestedTensor( - slice_map[(*name[:-1], ""+name[-1],)].view(values.dtype).view(values.shape), + storage_values.view(values.dtype).view(values.shape), **kwargs, ) return slice_map[name].view(x.dtype).view(x.shape) From 819e5d92b3be7dd4d66654cd7b075eca6edbad8d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Oct 2024 13:55:43 +0100 Subject: [PATCH 05/37] Update [ghstack-poisoned] --- tensordict/base.py | 64 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index c66ff0d04..8aa862c3e 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10457,7 +10457,9 @@ def to_pinmem(tensor, _to=to): self._sync_all() return result - def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking, compilable): + def _to_consolidated( + self, *, device, pin_memory, num_threads, non_blocking, compilable + ): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 @@ -10470,12 +10472,17 @@ def to(storage): storage = storage.pin_memory() storage_cast = storage.to(device, non_blocking=True) return storage_cast + storage_cast = to(storage) if compilable: - result = self._to_consolidated_compile(device=device, num_threads=num_threads, storage_cast=storage_cast) + result = self._to_consolidated_compile( + device=device, num_threads=num_threads, storage_cast=storage_cast + ) else: - result = self._to_consolidated_eager(device=device, num_threads=num_threads, storage_cast=storage_cast) + result = self._to_consolidated_eager( + device=device, num_threads=num_threads, storage_cast=storage_cast + ) if non_blocking in (False, None): if device.type == "cuda" and non_blocking is False: @@ -10578,12 +10585,18 @@ def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): keys.append(prefix + (k,)) for k, d in metadata.items(): if "leaves" in d: - get_tensors_length(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)) + get_tensors_length( + d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,) + ) if root: # l = torch.empty(len(lengths), dtype=torch.long) # l[torch.as_tensor(pos)] = torch.as_tensor(lengths) - out0 = [None, ] * len(pos) - out1 = [None, ] * len(pos) + out0 = [ + None, + ] * len(pos) + out1 = [ + None, + ] * len(pos) for p, l, k in zip(pos, lengths, keys): out0[p] = k out1[p] = l @@ -10610,6 +10623,9 @@ def copy_dict(d): slice_map = split_storage(_consolidated) + def view_as(src, dest): + return src.view(dest.dtype)[: dest.numel()].view(dest.shape) + def set_(name, x): if not isinstance(name, tuple): name = (name,) @@ -10631,11 +10647,21 @@ def set_(name, x): values = x._values lengths = x._lengths offsets = x._offsets - storage_offsets = slice_map[(*name[:-1], ""+name[-1],)] - 
kwargs["offsets"] = storage_offsets.view(offsets.dtype).view(offsets.shape) + storage_offsets = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + kwargs["offsets"] = view_as(storage_offsets, offsets) if lengths is not None: - storage_lengths = slice_map[(*name[:-1], ""+name[-1],)] - kwargs["lengths"] = storage_lengths.view(lengths.dtype).view(lengths.shape) + storage_lengths = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + kwargs["lengths"] = view_as(storage_lengths, lengths) ragged_source = lengths else: ragged_source = offsets @@ -10654,18 +10680,28 @@ def set_(name, x): ragged_source ] - storage_values = slice_map[(*name[:-1], ""+name[-1],)] + storage_values = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] return NestedTensor( - storage_values.view(values.dtype).view(values.shape), + view_as(storage_values, values), **kwargs, ) - return slice_map[name].view(x.dtype).view(x.shape) + return view_as(slice_map[name], x) result = self._fast_apply( - set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True, + set_, + device=torch.device(device), + num_threads=num_threads, + named=True, + nested_keys=True, ) result._consolidated = _consolidated return result + def _sync_all(self): if _has_cuda: # TODO: dynamo doesn't like torch.cuda.is_initialized From 7fd7b1c99704b82db998fa91b40467903e21827c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:04:02 +0100 Subject: [PATCH 06/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 19 ++++++++++++++++--- benchmarks/compile/compile_td_test.py | 7 +++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 7014c097c..a22d649d6 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -14,6 +14,13 @@ TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) +@pytest.fixture(autouse=True, scope="module") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + print("Emptying cache") + yield + + @pytest.fixture def td(): return TensorDict( @@ -52,7 +59,9 @@ def default_device(): pytest.skip("CUDA/MPS is not available") -@pytest.mark.parametrize("consolidated,compiled", [[False,False], [True,False],[True,True]]) +@pytest.mark.parametrize( + "consolidated,compiled", [[False, False], [True, False], [True, True]] +) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) @@ -60,23 +69,27 @@ class TestTo: def test_to(self, benchmark, consolidated, td, default_device, compiled): if consolidated: td = td.consolidate() + def to(td): return td.to(default_device) if compiled: to = torch.compile(to) - + for _ in range(3): + to(td) benchmark(to, td) def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled): if consolidated: njt_td = njt_td.consolidate() + def to(td): return td.to(default_device) if compiled: to = torch.compile(to) - + for _ in range(3): + to(njt_td) benchmark(to, njt_td) diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 3a1ef0ee1..7d877efb7 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -23,6 +23,13 @@ class MyTensorClass: f: torch.Tensor +@pytest.fixture(autouse=True, scope="module") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + print("Emptying cache") + yield + + # Functions def add_one(td): return td + 1 From 1ca6fe7dfcf598e90e7c6a39f360de50416c439d Mon Sep 
17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:04:52 +0100 Subject: [PATCH 07/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 2 +- benchmarks/compile/compile_td_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index a22d649d6..9db87297d 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -14,7 +14,7 @@ TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="function") def empty_compiler_cache(): torch._dynamo.reset_code_caches() print("Emptying cache") diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 7d877efb7..b87df1918 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -23,7 +23,7 @@ class MyTensorClass: f: torch.Tensor -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="function") def empty_compiler_cache(): torch._dynamo.reset_code_caches() print("Emptying cache") From e3ab749e09b347326316244c61d5723df11681ea Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:05:49 +0100 Subject: [PATCH 08/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 1 - benchmarks/compile/compile_td_test.py | 1 - 2 files changed, 2 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 9db87297d..546c30d85 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -17,7 +17,6 @@ @pytest.fixture(autouse=True, scope="function") def empty_compiler_cache(): torch._dynamo.reset_code_caches() - print("Emptying cache") yield diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index b87df1918..c07859490 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -26,7 +26,6 @@ class MyTensorClass: @pytest.fixture(autouse=True, scope="function") def empty_compiler_cache(): torch._dynamo.reset_code_caches() - print("Emptying cache") yield From b620cdfdf021eacbf9ff1912984483040dec91de Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:13:53 +0100 Subject: [PATCH 09/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 4 ++++ tensordict/base.py | 11 ++--------- tensordict/utils.py | 8 ++++++++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 546c30d85..00fb70370 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -74,8 +74,10 @@ def to(td): if compiled: to = torch.compile(to) + for _ in range(3): to(td) + benchmark(to, td) def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled): @@ -87,8 +89,10 @@ def to(td): if compiled: to = torch.compile(to) + for _ in range(3): to(njt_td) + benchmark(to, njt_td) diff --git a/tensordict/base.py b/tensordict/base.py index a98428caf..8ac728d6a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -72,7 +72,7 @@ _split_tensordict, _td_fields, _unravel_key_to_tuple, - _zip_strict, + _zip_strict,_to_escape_compile, cache, convert_ellipsis_to_idx, DeviceType, @@ -10512,14 +10512,7 @@ def _to_consolidated( storage = self._consolidated["storage"] - @torch.compiler.disable() - def to(storage): - if pin_memory: - storage = storage.pin_memory() - storage_cast = 
storage.to(device, non_blocking=True) - return storage_cast - - storage_cast = to(storage) + storage_cast = _to_escape_compile(storage) if compilable: result = self._to_consolidated_compile( diff --git a/tensordict/utils.py b/tensordict/utils.py index 280b224a0..5d14265a0 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2694,3 +2694,11 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): values, **kwargs, ) + + +@torch.compiler.disable() +def _to_escape_compile(storage, device, pin_memory): + if pin_memory: + storage = storage.pin_memory() + storage_cast = storage.to(device, non_blocking=True) + return storage_cast From 433d960ca7bc900525a98e81d0540f4015bb5b95 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:14:38 +0100 Subject: [PATCH 10/37] Update [ghstack-poisoned] --- tensordict/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 8ac728d6a..7891499c0 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10512,7 +10512,7 @@ def _to_consolidated( storage = self._consolidated["storage"] - storage_cast = _to_escape_compile(storage) + storage_cast = _to_escape_compile(storage, device=device, pin_memory=pin_memory) if compilable: result = self._to_consolidated_compile( From d9d4335251f913335e5c1df272d64860347b7ce9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:33:26 +0100 Subject: [PATCH 11/37] Update [ghstack-poisoned] --- tensordict/base.py | 93 ++++++++++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 31 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 7891499c0..8de78258e 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10669,20 +10669,67 @@ def set_(name, x): if not isinstance(name, tuple): name = (name,) if x.is_nested: - from torch._subclasses.fake_tensor import FakeTensor - from torch._subclasses.functional_tensor import FunctionalTensor - from torch.nested._internal.nested_tensor import ( - _tensor_symint_registry, - NestedTensor, - ) - from torch.nested._internal.ops import extract_kwargs - if x.layout != torch.jagged: raise RuntimeError( "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " "Please raise an issue on GitHub." 
) - kwargs = extract_kwargs(x) + # from torch._subclasses.fake_tensor import FakeTensor + # from torch._subclasses.functional_tensor import FunctionalTensor + # from torch.nested._internal.nested_tensor import ( + # _tensor_symint_registry, + # NestedTensor, + # ) + # from torch.nested._internal.ops import extract_kwargs + # + # kwargs = extract_kwargs(x) + # values = x._values + # lengths = x._lengths + # offsets = x._offsets + # storage_offsets = slice_map[ + # ( + # *name[:-1], + # "" + name[-1], + # ) + # ] + # kwargs["offsets"] = view_as(storage_offsets, offsets) + # if lengths is not None: + # storage_lengths = slice_map[ + # ( + # *name[:-1], + # "" + name[-1], + # ) + # ] + # kwargs["lengths"] = view_as(storage_lengths, lengths) + # ragged_source = lengths + # else: + # ragged_source = offsets + # new_thing = kwargs.get("lengths", kwargs.get("offsets")) + # if isinstance(new_thing, (FakeTensor, FunctionalTensor)): + # from torch._subclasses.functional_tensor import ( + # mb_unwrap_functional_tensor, + # ) + # + # # Temporary hack until we have the union find + # tgt = mb_unwrap_functional_tensor(new_thing) + # src = mb_unwrap_functional_tensor(ragged_source) + # tgt.nested_int_memo = src.nested_int_memo + # else: + # _tensor_symint_registry[new_thing] = _tensor_symint_registry[ + # ragged_source + # ] + # + # storage_values = slice_map[ + # ( + # *name[:-1], + # "" + name[-1], + # ) + # ] + # return NestedTensor( + # view_as(storage_values, values), + # **kwargs, + # ) + from torch.nested import nested_tensor_from_jagged values = x._values lengths = x._lengths offsets = x._offsets @@ -10692,7 +10739,7 @@ def set_(name, x): "" + name[-1], ) ] - kwargs["offsets"] = view_as(storage_offsets, offsets) + offsets = view_as(storage_offsets, offsets) if lengths is not None: storage_lengths = slice_map[ ( @@ -10700,35 +10747,19 @@ def set_(name, x): "" + name[-1], ) ] - kwargs["lengths"] = view_as(storage_lengths, lengths) - ragged_source = lengths - else: - ragged_source = offsets - new_thing = kwargs.get("lengths", kwargs.get("offsets")) - if isinstance(new_thing, (FakeTensor, FunctionalTensor)): - from torch._subclasses.functional_tensor import ( - mb_unwrap_functional_tensor, - ) - - # Temporary hack until we have the union find - tgt = mb_unwrap_functional_tensor(new_thing) - src = mb_unwrap_functional_tensor(ragged_source) - tgt.nested_int_memo = src.nested_int_memo - else: - _tensor_symint_registry[new_thing] = _tensor_symint_registry[ - ragged_source - ] - + lengths = view_as(storage_lengths, lengths) storage_values = slice_map[ ( *name[:-1], "" + name[-1], ) ] - return NestedTensor( + return nested_tensor_from_jagged( view_as(storage_values, values), - **kwargs, + offsets=offsets, + lengths=lengths ) + return view_as(slice_map[name], x) result = self._fast_apply( From 2580f9b886f0b18c3b0a57c9b49ee88b60515bc9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:45:25 +0100 Subject: [PATCH 12/37] Update [ghstack-poisoned] --- tensordict/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index 5d14265a0..ffa90f371 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2696,9 +2696,13 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): ) -@torch.compiler.disable() +@torch.library.custom_op("tensordict::_to_escape_compile", mutates_args=()) def _to_escape_compile(storage, device, pin_memory): if pin_memory: storage = storage.pin_memory() storage_cast = storage.to(device, 
non_blocking=True) return storage_cast + +@_to_escape_compile.register_fake +def _(storage, device, pin_memory): + return torch.empty_like(storage, device=device) From 97425b47c7bb1288268bed433bd0b0dac3cd48a0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:46:04 +0100 Subject: [PATCH 13/37] Update [ghstack-poisoned] --- tensordict/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index ffa90f371..26d485237 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2697,12 +2697,12 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): @torch.library.custom_op("tensordict::_to_escape_compile", mutates_args=()) -def _to_escape_compile(storage, device, pin_memory): +def _to_escape_compile(storage: torch.Tensor, device: torch.device, pin_memory: bool): if pin_memory: storage = storage.pin_memory() storage_cast = storage.to(device, non_blocking=True) return storage_cast @_to_escape_compile.register_fake -def _(storage, device, pin_memory): +def _(storage: torch.Tensor, device: torch.device, pin_memory: bool): return torch.empty_like(storage, device=device) From 9132b449a8bf26e26f09727ccd53e3b3b1bef840 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:47:30 +0100 Subject: [PATCH 14/37] Update [ghstack-poisoned] --- tensordict/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index 26d485237..4ee52810d 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2697,12 +2697,12 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): @torch.library.custom_op("tensordict::_to_escape_compile", mutates_args=()) -def _to_escape_compile(storage: torch.Tensor, device: torch.device, pin_memory: bool): +def _to_escape_compile(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: if pin_memory: storage = storage.pin_memory() storage_cast = storage.to(device, non_blocking=True) return storage_cast @_to_escape_compile.register_fake -def _(storage: torch.Tensor, device: torch.device, pin_memory: bool): +def _(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: return torch.empty_like(storage, device=device) From 2e7365986921c18c8bcf3d8005c4541c23932ef6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:57:23 +0100 Subject: [PATCH 15/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 17 ++++++++++------- tensordict/base.py | 8 ++++---- tensordict/utils.py | 5 ++++- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 00fb70370..b980a2087 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -59,36 +59,39 @@ def default_device(): @pytest.mark.parametrize( - "consolidated,compiled", [[False, False], [True, False], [True, True]] + "consolidated,compile_mode", + [[False, False], [True, False], [True, "default"], [True, "reduce-overhead"]], ) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) class TestTo: - def test_to(self, benchmark, consolidated, td, default_device, compiled): + def test_to(self, benchmark, consolidated, td, default_device, compile_mode): if consolidated: td = td.consolidate() def to(td): return td.to(default_device) - if compiled: - to = torch.compile(to) + if compile_mode: + to = torch.compile(to, mode=compile_mode) for _ in range(3): to(td) benchmark(to, td) - 
def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled): + def test_to_njt( + self, benchmark, consolidated, njt_td, default_device, compile_mode + ): if consolidated: njt_td = njt_td.consolidate() def to(td): return td.to(default_device) - if compiled: - to = torch.compile(to) + if compile_mode: + to = torch.compile(to, mode=compile_mode) for _ in range(3): to(njt_td) diff --git a/tensordict/base.py b/tensordict/base.py index 8de78258e..baefdcaf4 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -71,8 +71,9 @@ _shape, _split_tensordict, _td_fields, + _to_escape_compile, _unravel_key_to_tuple, - _zip_strict,_to_escape_compile, + _zip_strict, cache, convert_ellipsis_to_idx, DeviceType, @@ -10730,6 +10731,7 @@ def set_(name, x): # **kwargs, # ) from torch.nested import nested_tensor_from_jagged + values = x._values lengths = x._lengths offsets = x._offsets @@ -10755,9 +10757,7 @@ def set_(name, x): ) ] return nested_tensor_from_jagged( - view_as(storage_values, values), - offsets=offsets, - lengths=lengths + view_as(storage_values, values), offsets=offsets, lengths=lengths ) return view_as(slice_map[name], x) diff --git a/tensordict/utils.py b/tensordict/utils.py index 4ee52810d..793a284ae 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2697,12 +2697,15 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): @torch.library.custom_op("tensordict::_to_escape_compile", mutates_args=()) -def _to_escape_compile(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: +def _to_escape_compile( + storage: torch.Tensor, device: torch.device, pin_memory: bool +) -> torch.Tensor: if pin_memory: storage = storage.pin_memory() storage_cast = storage.to(device, non_blocking=True) return storage_cast + @_to_escape_compile.register_fake def _(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: return torch.empty_like(storage, device=device) From 933abb97c7708a7884c3262cd026af9f156c4efc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 09:13:50 +0100 Subject: [PATCH 16/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 37 ++++++++++++++--------- tensordict/base.py | 55 ----------------------------------- 2 files changed, 24 insertions(+), 68 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index b980a2087..54c3ed869 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -59,46 +59,57 @@ def default_device(): @pytest.mark.parametrize( - "consolidated,compile_mode", - [[False, False], [True, False], [True, "default"], [True, "reduce-overhead"]], + "consolidated,compile_mode,num_threads", + [ + [False, False, None], + [True, False, None], + [True, False, 4], + [True, False, 16], + [True, "default", 0], + ], ) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) class TestTo: - def test_to(self, benchmark, consolidated, td, default_device, compile_mode): + def test_to( + self, benchmark, consolidated, td, default_device, compile_mode, num_threads + ): if consolidated: td = td.consolidate() - def to(td): - return td.to(default_device) + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) if compile_mode: to = torch.compile(to, mode=compile_mode) for _ in range(3): - to(td) + to(td, num_threads=num_threads) - benchmark(to, td) + benchmark(to, td, num_threads) def test_to_njt( - self, benchmark, consolidated, njt_td, default_device, compile_mode + 
self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): if consolidated: njt_td = njt_td.consolidate() - def to(td): - return td.to(default_device) + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) if compile_mode: to = torch.compile(to, mode=compile_mode) for _ in range(3): - to(njt_td) + to(njt_td, num_threads=num_threads) - benchmark(to, njt_td) + benchmark(to, njt_td, num_threads) if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main( + [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + + unknown + ) diff --git a/tensordict/base.py b/tensordict/base.py index baefdcaf4..3e5bf5e9c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10675,61 +10675,6 @@ def set_(name, x): "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " "Please raise an issue on GitHub." ) - # from torch._subclasses.fake_tensor import FakeTensor - # from torch._subclasses.functional_tensor import FunctionalTensor - # from torch.nested._internal.nested_tensor import ( - # _tensor_symint_registry, - # NestedTensor, - # ) - # from torch.nested._internal.ops import extract_kwargs - # - # kwargs = extract_kwargs(x) - # values = x._values - # lengths = x._lengths - # offsets = x._offsets - # storage_offsets = slice_map[ - # ( - # *name[:-1], - # "" + name[-1], - # ) - # ] - # kwargs["offsets"] = view_as(storage_offsets, offsets) - # if lengths is not None: - # storage_lengths = slice_map[ - # ( - # *name[:-1], - # "" + name[-1], - # ) - # ] - # kwargs["lengths"] = view_as(storage_lengths, lengths) - # ragged_source = lengths - # else: - # ragged_source = offsets - # new_thing = kwargs.get("lengths", kwargs.get("offsets")) - # if isinstance(new_thing, (FakeTensor, FunctionalTensor)): - # from torch._subclasses.functional_tensor import ( - # mb_unwrap_functional_tensor, - # ) - # - # # Temporary hack until we have the union find - # tgt = mb_unwrap_functional_tensor(new_thing) - # src = mb_unwrap_functional_tensor(ragged_source) - # tgt.nested_int_memo = src.nested_int_memo - # else: - # _tensor_symint_registry[new_thing] = _tensor_symint_registry[ - # ragged_source - # ] - # - # storage_values = slice_map[ - # ( - # *name[:-1], - # "" + name[-1], - # ) - # ] - # return NestedTensor( - # view_as(storage_values, values), - # **kwargs, - # ) from torch.nested import nested_tensor_from_jagged values = x._values From 9ef9f1bd660418cc20ab7c0f54b89aca9958ca05 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 09:18:17 +0100 Subject: [PATCH 17/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 54c3ed869..dad689bab 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -65,7 +65,7 @@ def default_device(): [True, False, None], [True, False, 4], [True, False, 16], - [True, "default", 0], + [True, "default", None], ], ) @pytest.mark.skipif( From 26642cac21d94b8ef9f5fea61d8e4c6a77addd44 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 10:26:32 +0100 Subject: [PATCH 18/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 
f94700b1e..641233657 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -76,7 +76,7 @@ class TestTo: def test_to( self, benchmark, consolidated, td, default_device, compile_mode, num_threads ): - tensordict_logger.info("td size (bytes)", td.bytes()) + tensordict_logger.info(f"td size (Gb) {td.bytes() / 1024 / 1024 / 1024 :.2f} GB") if consolidated: td = td.consolidate() @@ -94,7 +94,7 @@ def to(td, num_threads): def test_to_njt( self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): - tensordict_logger.info("njt_td size (bytes)", njt_td.bytes()) + tensordict_logger.info(f"njtd size (Gb) {td.bytes() / 1024 / 1024 / 1024 :.2f} GB") if consolidated: njt_td = njt_td.consolidate() From a9dbfb30c92f95c02cb1458cda9fb771efa2cf60 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 10:27:48 +0100 Subject: [PATCH 19/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 641233657..3771e68c0 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -76,7 +76,9 @@ class TestTo: def test_to( self, benchmark, consolidated, td, default_device, compile_mode, num_threads ): - tensordict_logger.info(f"td size (Gb) {td.bytes() / 1024 / 1024 / 1024 :.2f} GB") + tensordict_logger.info( + f"td size {td.bytes() / 1024 / 1024:.2f} Mb" + ) if consolidated: td = td.consolidate() @@ -94,7 +96,9 @@ def to(td, num_threads): def test_to_njt( self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): - tensordict_logger.info(f"njtd size (Gb) {td.bytes() / 1024 / 1024 / 1024 :.2f} GB") + tensordict_logger.info( + f"njtd size {td.bytes() / 1024 / 1024 :.2f} Mb" + ) if consolidated: njt_td = njt_td.consolidate() From 5239f2e223e105dca4d9524005e0345e0dc096bc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 10:28:04 +0100 Subject: [PATCH 20/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 3771e68c0..672f10ac9 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -97,7 +97,7 @@ def test_to_njt( self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): tensordict_logger.info( - f"njtd size {td.bytes() / 1024 / 1024 :.2f} Mb" + f"njtd size {njtd.bytes() / 1024 / 1024 :.2f} Mb" ) if consolidated: njt_td = njt_td.consolidate() From fb2098894d39b89bd25e83a0c3b5640e1d3af75c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 10:28:31 +0100 Subject: [PATCH 21/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 672f10ac9..f7571bfbb 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -97,7 +97,7 @@ def test_to_njt( self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): tensordict_logger.info( - f"njtd size {njtd.bytes() / 1024 / 1024 :.2f} Mb" + f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb" ) if consolidated: njt_td = njt_td.consolidate() From 60f029b4bb3b519e3b3ffb547179ffc3ddf67f6f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 14:11:24 +0100 Subject: [PATCH 22/37] Update [ghstack-poisoned] --- 
benchmarks/common/h2d_test.py | 70 ++++++++++++++++++++++------------- tensordict/base.py | 2 +- tensordict/tensorclass.py | 1 + 3 files changed, 47 insertions(+), 26 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index f7571bfbb..b7df5e460 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -4,35 +4,41 @@ # LICENSE file in the root directory of this source tree. import argparse +from typing import Any import pytest import torch from packaging import version -from tensordict import TensorDict +from tensordict import tensorclass, TensorDict from tensordict.utils import logger as tensordict_logger TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) +@tensorclass +class NJT: + _values: torch.Tensor + _offsets: torch.Tensor + _lengths: torch.Tensor + njt_shape: Any = None + + @classmethod + def from_njt(cls, njt_tensor): + return NJT( + _values=njt_tensor._values, + _offsets=njt_tensor._offsets, + _lengths=njt_tensor._lengths, + njt_shape=njt_tensor.size(0), + ) + + @pytest.fixture(autouse=True, scope="function") def empty_compiler_cache(): torch._dynamo.reset_code_caches() yield -@pytest.fixture -def td(): - return TensorDict( - { - str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)} - for i in range(16) - }, - batch_size=[16], - device="cpu", - ) - - def _make_njt(): lengths = torch.arange(24, 1, -1) offsets = torch.cat([lengths[:1] * 0, lengths]).cumsum(0) @@ -41,14 +47,27 @@ def _make_njt(): ) -@pytest.fixture -def njt_td(): +def _njt_td(): return TensorDict( {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, device="cpu", ) +@pytest.fixture +def njt_td(): + return _njt_td() + + +@pytest.fixture +def td(): + njtd = _njt_td() + for k0, v0 in njtd.items(): + for k1, v1 in v0.items(): + njtd[k0, k1] = NJT.from_njt(v1) + return njtd + + @pytest.fixture def default_device(): if torch.cuda.is_available(): @@ -64,8 +83,9 @@ def default_device(): [ [False, False, None], [True, False, None], - [True, False, 4], - [True, False, 16], + ["within", False, None], + # [True, False, 4], + # [True, False, 16], [True, "default", None], ], ) @@ -76,13 +96,13 @@ class TestTo: def test_to( self, benchmark, consolidated, td, default_device, compile_mode, num_threads ): - tensordict_logger.info( - f"td size {td.bytes() / 1024 / 1024:.2f} Mb" - ) - if consolidated: + tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") + if consolidated is True: td = td.consolidate() def to(td, num_threads): + if consolidated == "within": + td = td.consolidate() return td.to(default_device, num_threads=num_threads) if compile_mode: @@ -96,13 +116,13 @@ def to(td, num_threads): def test_to_njt( self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): - tensordict_logger.info( - f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb" - ) - if consolidated: + tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") + if consolidated is True: njt_td = njt_td.consolidate() def to(td, num_threads): + if consolidated == "within": + td = td.consolidate() return td.to(default_device, num_threads=num_threads) if compile_mode: diff --git a/tensordict/base.py b/tensordict/base.py index 09f047e5d..d7bf21116 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3574,7 +3574,7 @@ def saved_path(self): ) # Generic method to get a class metadata - def _reduce_get_metadata(self): + def _reduce_get_metadata(self) -> dict: return { "device": 
str(self.device) if self.device is not None else None, "names": self.names, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 8906eefd4..fa820ac7a 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -134,6 +134,7 @@ def __subclasscheck__(self, subclass): "_multithread_rebuild", # rebuild checks if self is a non tensor "_propagate_lock", "_propagate_unlock", + "_reduce_get_metadata", "_values_list", "data_ptr", "dim", From a3dbf308f40fcc38ea38b9b314be3322f8175145 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 14:16:27 +0100 Subject: [PATCH 23/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index b7df5e460..8b528d2d6 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -100,10 +100,12 @@ def test_to( if consolidated is True: td = td.consolidate() - def to(td, num_threads): - if consolidated == "within": - td = td.consolidate() - return td.to(default_device, num_threads=num_threads) + if consolidated == "within": + def to(td, num_threads): + return td.consolidate().to(default_device, num_threads=num_threads) + else: + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) if compile_mode: to = torch.compile(to, mode=compile_mode) @@ -120,10 +122,12 @@ def test_to_njt( if consolidated is True: njt_td = njt_td.consolidate() - def to(td, num_threads): - if consolidated == "within": - td = td.consolidate() - return td.to(default_device, num_threads=num_threads) + if consolidated == "within": + def to(td, num_threads): + return td.consolidate().to(default_device, num_threads=num_threads) + else: + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) if compile_mode: to = torch.compile(to, mode=compile_mode) From de06570a438933e8cf6a734883ead5591e9a69aa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 11:46:44 +0100 Subject: [PATCH 24/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 12 ++- tensordict/base.py | 135 +++++++++++++++++++--------------- tensordict/tensorclass.py | 2 + tensordict/utils.py | 44 +++++++++++ 4 files changed, 133 insertions(+), 60 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 8b528d2d6..6973592b7 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -99,11 +99,15 @@ def test_to( tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") if consolidated is True: td = td.consolidate() + pin_mem = default_device.type == "cuda" if consolidated == "within": + def to(td, num_threads): - return td.consolidate().to(default_device, num_threads=num_threads) + return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(default_device, num_threads=num_threads) + else: + def to(td, num_threads): return td.to(default_device, num_threads=num_threads) @@ -121,11 +125,15 @@ def test_to_njt( tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") if consolidated is True: njt_td = njt_td.consolidate() + pin_mem = default_device.type == "cuda" if consolidated == "within": + def to(td, num_threads): - return td.consolidate().to(default_device, num_threads=num_threads) + return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(default_device, num_threads=num_threads) + else: + def to(td, num_threads): return td.to(default_device, 
num_threads=num_threads) diff --git a/tensordict/base.py b/tensordict/base.py index d7bf21116..026b284ca 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -93,6 +93,9 @@ TensorDictFuture, unravel_key, unravel_key_list, + view_and_pad, + view_cat_split, + view_old_as_new, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor from torch.nn.parameter import UninitializedTensorMixin @@ -3645,8 +3648,9 @@ def assign( total_key = key if isinstance(key, tuple) else (key,) total_key = track_key + total_key cls = type(value) - if issubclass(cls, torch.Tensor): + if cls is Tensor or issubclass(cls, Tensor): pass + # must go before is_tensor_collection elif _is_non_tensor(cls): if requires_metadata: metadata_dict["non_tensors"][key] = ( @@ -3663,19 +3667,22 @@ def assign( "leaves": {}, "cls_metadata": value._reduce_get_metadata(), } - local_assign = partial( - assign, + local_assign = lambda key, value: assign( + key, + value, track_key=total_key, metadata_dict=metadata_dict_key, flat_size=flat_size, ) - value._fast_apply( + r = value._fast_apply( local_assign, named=True, nested_keys=True, call_on_nested=True, is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, + filter_empty=True, ) + assert r is None return # Tensors: DTensor, nested and then regular if hasattr(value, "full_tensor"): @@ -3772,6 +3779,7 @@ def consolidate( share_memory: bool = False, pin_memory: bool = False, metadata: bool = False, + set_on_tensor: bool = False, ) -> None: """Consolidates the tensordict content in a single storage for fast serialization. @@ -3892,12 +3900,6 @@ def consolidate( offsets = torch.tensor([0] + flat_size).cumsum(0).tolist() - def view_old_as_new(v, oldv): - v = v.view(oldv.dtype) - if v.numel() > oldv.numel(): - return v[: oldv.numel()].view(oldv.shape) - return v.view(oldv.shape) - if num_threads > 0: def assign( @@ -3920,14 +3922,22 @@ def assign( pad = exp_length - v_pad.numel() if pad: v_pad = torch.cat([v_pad, v_pad.new_zeros(pad)]) - storage[start:stop].copy_(v_pad, non_blocking=non_blocking) - storage_slice = storage[start:stop] + storage_slice.copy_(v_pad, non_blocking=non_blocking) + shape, dtype = v.shape, v.dtype new_v = storage_slice.view(dtype) if pad: new_v = new_v[: v.numel()] new_v = new_v.view(shape) + if set_on_tensor: + v.set_( + new_v.untyped_storage(), + storage_offset=new_v.storage_offset(), + stride=new_v.stride(), + size=new_v.size(), + ) + return flat_dict[k] = new_v njts = {} @@ -3961,76 +3971,81 @@ def assign( stop=offsets[i + 1], njts=njts, ) - for njt_key, njt in njts.items(): - newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) - njt_key_values = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_offset = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_lengths = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - val = _rebuild_njt_from_njt( - njt, - values=flat_dict.pop(njt_key_values), - offsets=flat_dict.pop(njt_key_offset), - lengths=flat_dict.pop(njt_key_lengths, None), - ) - del flat_dict[njt_key] - flat_dict[newkey] = val + if not set_on_tensor: + for njt_key, njt in njts.items(): + newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) + njt_key_values = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) + njt_key_offset = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) + njt_key_lengths = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) + val = _rebuild_njt_from_njt( + njt, + values=flat_dict.pop(njt_key_values), + offsets=flat_dict.pop(njt_key_offset), + lengths=flat_dict.pop(njt_key_lengths, None), + ) 
+ del flat_dict[njt_key] + flat_dict[newkey] = val if non_blocking and device.type != "cuda": # sync if needed self._sync_all() + if set_on_tensor: + return self else: - def _view_and_pad(tensor): - result = tensor.view(-1).view(torch.uint8) - # result must always have a multiple of 8 elements - pad = 0 - if need_padding: - pad = result.numel() % 8 - if pad != 0: - result = torch.cat([result, result.new_zeros(8 - pad)]) - return result, pad - items = [] for v in flat_dict.values(): if v.is_nested: + items.append(None) continue if v.device != storage.device: v = v.to(storage.device, non_blocking=non_blocking) - stride = v.stride() - if (stride and stride[-1] != 1) or v.storage_offset(): + if is_dynamo_compiling(): v = v.clone(memory_format=torch.contiguous_format) - v, pad = _view_and_pad(v) + else: + stride = v.stride() + if (stride and stride[-1] != 1) or v.storage_offset(): + v = v.clone(memory_format=torch.contiguous_format) + # v, pad = _view_and_pad(v) items.append(v) - if non_blocking and device.type != "cuda": - # sync if needed - self._sync_all() - torch.cat(items, out=storage) - for v, (k, oldv) in _zip_strict( - storage.split(flat_size), list(flat_dict.items()) - ): + + items = view_cat_split( + self, + items, + storage, + need_padding, + non_blocking, + device, + flat_size, + set_on_tensor, + ) + if set_on_tensor: + return self + + for k, v in _zip_strict(flat_dict.keys(), items): if not k[-1].startswith("<"): - flat_dict[k] = view_old_as_new(v, oldv) + flat_dict[k] = v elif k[-1].startswith(""): # NJT/NT always comes before offsets/shapes - nt = oldv + nt = flat_dict[k] assert not v.numel() nt_lengths = None del flat_dict[k] elif k[-1].startswith(""): - nt_vaues = view_old_as_new(v, oldv) + nt_vaues = v del flat_dict[k] elif k[-1].startswith(""): - nt_lengths = view_old_as_new(v, oldv) + nt_lengths = v del flat_dict[k] elif k[-1].startswith(""): newk = k[:-1] + (k[-1].replace("", ""),) - nt_offsets = view_old_as_new(v, oldv) + nt_offsets = v del flat_dict[k] val = _rebuild_njt_from_njt( @@ -4044,7 +4059,7 @@ def _view_and_pad(tensor): # another nested tensor. 
del nt, nt_vaues, nt_offsets, nt_lengths else: - flat_dict[k] = view_old_as_new(v, oldv) + flat_dict[k] = v def assign_val(key, val): if isinstance(key, str): @@ -4060,12 +4075,16 @@ def assign_val(key, val): device = None else: device = None + if inplace: + result = self + else: + result = None result = self._fast_apply( assign_val, named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - out=self if inplace else None, + out=result, device=device, ) result._consolidated = {"storage": storage, "metadata": metadata_dict} diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index fa820ac7a..e3eb4aaaa 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1166,6 +1166,8 @@ def check_out(kwargs, result): return self def deliver_result(result): + if result is None: + return if isinstance(result, TensorDictBase) and not check_out(kwargs, result): if not is_dynamo_compiling(): non_tensordict = super(type(self), self).__getattribute__( diff --git a/tensordict/utils.py b/tensordict/utils.py index 793a284ae..55fc96855 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2709,3 +2709,47 @@ def _to_escape_compile( @_to_escape_compile.register_fake def _(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: return torch.empty_like(storage, device=device) + + +def view_and_pad(tensor: torch.Tensor, need_padding: bool) -> torch.Tensor: + result = tensor.view(-1).view(torch.uint8) + # result must always have a multiple of 8 elements + if need_padding: + pad = result.numel() % 8 + if pad != 0: + result = torch.cat([result, result.new_zeros(8 - pad)]) + return result + + +def view_old_as_new(v: torch.Tensor, oldv: torch.Tensor) -> torch.Tensor: + if oldv is None: + return v + v = v.view(oldv.dtype) + if v.numel() > oldv.numel(): + return v[: oldv.numel()].view(oldv.shape) + return v.view(oldv.shape) + + +@torch.compiler.disable() +def view_cat_split( + td, items, storage, need_padding, non_blocking, device, flat_size, set_on_tensor +): + items_flat = [view_and_pad(v, need_padding) for v in items] + if non_blocking and device.type != "cuda": + # sync if needed + td._sync_all() + torch.cat(items_flat, out=storage) + # TODO: breaks with NJT + result = [ + view_old_as_new(v, oldv) + for (v, oldv) in zip(storage.split(flat_size), items, strict=True) + ] + if set_on_tensor: + for t_dest, t_src in zip(result, items): + t_src.set_( + t_dest.untyped_storage(), + storage_offset=t_dest.storage_offset(), + stride=t_dest.stride(), + size=t_dest.size(), + ) + return result From 11e954634c71eb7631969071cf8bff4f743b9344 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 12:51:24 +0100 Subject: [PATCH 25/37] Update [ghstack-poisoned] --- tensordict/base.py | 1 - tensordict/utils.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 026b284ca..c152df5e7 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -4012,7 +4012,6 @@ def assign( stride = v.stride() if (stride and stride[-1] != 1) or v.storage_offset(): v = v.clone(memory_format=torch.contiguous_format) - # v, pad = _view_and_pad(v) items.append(v) items = view_cat_split( diff --git a/tensordict/utils.py b/tensordict/utils.py index 55fc96855..240f44b09 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2734,7 +2734,7 @@ def view_old_as_new(v: torch.Tensor, oldv: torch.Tensor) -> torch.Tensor: def view_cat_split( td, items, storage, need_padding, non_blocking, device, flat_size, 
set_on_tensor ): - items_flat = [view_and_pad(v, need_padding) for v in items] + items_flat = [view_and_pad(v, need_padding) for v in items if v is not None] if non_blocking and device.type != "cuda": # sync if needed td._sync_all() From b348657270569aa582d27054b6f103e5a4cedaf4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:02:33 +0100 Subject: [PATCH 26/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 8 ++++++-- tensordict/base.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 6973592b7..e9200ddf4 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -104,7 +104,9 @@ def test_to( if consolidated == "within": def to(td, num_threads): - return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(default_device, num_threads=num_threads) + return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to( + default_device, num_threads=num_threads + ) else: @@ -130,7 +132,9 @@ def test_to_njt( if consolidated == "within": def to(td, num_threads): - return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(default_device, num_threads=num_threads) + return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to( + default_device, num_threads=num_threads + ) else: diff --git a/tensordict/base.py b/tensordict/base.py index c152df5e7..ffc23d86a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -4027,7 +4027,7 @@ def assign( if set_on_tensor: return self - for k, v in _zip_strict(flat_dict.keys(), items): + for k, v in _zip_strict(list(flat_dict.keys()), items): if not k[-1].startswith("<"): flat_dict[k] = v elif k[-1].startswith(""): From 506aa62b744de65bf1862b668443aba8f0dfefdd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:04:14 +0100 Subject: [PATCH 27/37] Update [ghstack-poisoned] --- tensordict/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensordict/utils.py b/tensordict/utils.py index 240f44b09..1cee55a2f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2746,6 +2746,9 @@ def view_cat_split( ] if set_on_tensor: for t_dest, t_src in zip(result, items): + if t_src is None: + # njt is decomposed + continue t_src.set_( t_dest.untyped_storage(), storage_offset=t_dest.storage_offset(), From 71b4a0069b32693bdece9b692ddc407f059d11f0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:21:25 +0100 Subject: [PATCH 28/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 8 +++--- tensordict/tensorclass.py | 52 +++++++++++++++++------------------ 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index e9200ddf4..715c84f46 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -97,9 +97,9 @@ def test_to( self, benchmark, consolidated, td, default_device, compile_mode, num_threads ): tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") - if consolidated is True: - td = td.consolidate() pin_mem = default_device.type == "cuda" + if consolidated is True: + td = td.consolidate(pin_memory=pin_mem, set_on_tensor=True) if consolidated == "within": @@ -125,9 +125,9 @@ def test_to_njt( self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads ): tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") - if consolidated is True: - njt_td = njt_td.consolidate() pin_mem = default_device.type == 
"cuda" + if consolidated is True: + njt_td = njt_td.consolidate(pin_memory=pin_mem, set_on_tensor=True) if consolidated == "within": diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index e3eb4aaaa..e7f5faade 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1144,6 +1144,30 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): + def check_out(kwargs, result): + out = kwargs.get("out") + if out is result: + # No need to transform output + return True + return False + + def deliver_result(self, result, kwargs): + if result is None: + return + if isinstance(result, TensorDictBase) and not check_out(kwargs, result): + if not is_dynamo_compiling(): + non_tensordict = super(type(self), self).__getattribute__( + "_non_tensordict" + ) + else: + non_tensordict = self._non_tensordict + non_tensordict = dict(non_tensordict) + if copy_non_tensor: + # use tree_map to copy + non_tensordict = tree_map(lambda x: x, non_tensordict) + return self._from_tensordict(result, non_tensordict) + return result + def wrapped_func(self, *args, **kwargs): if not is_dynamo_compiling(): td = super(type(self), self).__getattribute__("_tensordict") @@ -1155,36 +1179,12 @@ def wrapped_func(self, *args, **kwargs): if no_wrap: return result - def check_out(kwargs, result): - out = kwargs.get("out") - if out is result: - # No need to transform output - return True - return False - if result is td: return self - def deliver_result(result): - if result is None: - return - if isinstance(result, TensorDictBase) and not check_out(kwargs, result): - if not is_dynamo_compiling(): - non_tensordict = super(type(self), self).__getattribute__( - "_non_tensordict" - ) - else: - non_tensordict = self._non_tensordict - non_tensordict = dict(non_tensordict) - if copy_non_tensor: - # use tree_map to copy - non_tensordict = tree_map(lambda x: x, non_tensordict) - return self._from_tensordict(result, non_tensordict) - return result - if isinstance(result, tuple): - return tuple(deliver_result(r) for r in result) - return deliver_result(result) + return tuple(deliver_result(self, r, kwargs) for r in result) + return deliver_result(self, result, kwargs) return wrapped_func From ff1fc260b1cae42f94001c0c5cefa5a36ede6c0e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:23:30 +0100 Subject: [PATCH 29/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 715c84f46..47ec626fe 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -86,7 +86,7 @@ def default_device(): ["within", False, None], # [True, False, 4], # [True, False, 16], - [True, "default", None], + # [True, "default", None], ], ) @pytest.mark.skipif( From 9c7ed32e88b44f9bf350c3afd252f989df4c50b7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:37:44 +0100 Subject: [PATCH 30/37] Update [ghstack-poisoned] --- tensordict/base.py | 49 +++++++++++++++++++--------------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index ffc23d86a..888daac21 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3580,12 +3580,12 @@ def saved_path(self): def _reduce_get_metadata(self) -> dict: return { "device": str(self.device) if self.device is not None else None, - "names": self.names, + "names": 
self._maybe_names(), "batch_size": list(self.batch_size), "is_locked": self._is_locked, } - # @cache # noqa: B019 + @torch.compile(mode="reduce-overhead") def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): """Returns a nested dictionary of metadata, a flat Dict[NestedKey, Tensor] containing tensor data and a list of tensor sizes.""" if dtype is NO_DEFAULT: @@ -3674,7 +3674,7 @@ def assign( metadata_dict=metadata_dict_key, flat_size=flat_size, ) - r = value._fast_apply( + value._fast_apply( local_assign, named=True, nested_keys=True, @@ -3682,7 +3682,6 @@ def assign( is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, filter_empty=True, ) - assert r is None return # Tensors: DTensor, nested and then regular if hasattr(value, "full_tensor"): @@ -10618,14 +10617,26 @@ def _to_consolidated( storage = self._consolidated["storage"] storage_cast = _to_escape_compile(storage, device=device, pin_memory=pin_memory) + _consolidated = { + "storage": storage_cast, + } + if "metadata" in self._consolidated: + # faster than deepcopy + def copy_dict(d): + return { + k: v if not isinstance(v, dict) else copy_dict(v) + for k, v in d.items() + } + + _consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) if compilable: result = self._to_consolidated_compile( - device=device, num_threads=num_threads, storage_cast=storage_cast + device=device, num_threads=num_threads, storage_cast=storage_cast, _consolidated=_consolidated, ) else: result = self._to_consolidated_eager( - device=device, num_threads=num_threads, storage_cast=storage_cast + device=device, num_threads=num_threads, storage_cast=storage_cast, _consolidated=_consolidated, ) if non_blocking in (False, None): @@ -10642,7 +10653,7 @@ def _to_consolidated( return result - def _to_consolidated_eager(self, *, device, num_threads, storage_cast): + def _to_consolidated_eager(self, *, device, num_threads, storage_cast, _consolidated): untyped_storage = storage_cast.untyped_storage() @@ -10702,19 +10713,10 @@ def set_(x): result = self._fast_apply( set_, device=torch.device(device), num_threads=num_threads ) - result._consolidated = {"storage": storage_cast} - if "metadata" in self._consolidated: - # faster than deepcopy - def copy_dict(d): - return { - k: v if not isinstance(v, dict) else copy_dict(v) - for k, v in d.items() - } - - result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) + result._consolidated = _consolidated return result - def _to_consolidated_compile(self, *, device, num_threads, storage_cast): + def _to_consolidated_compile(self, *, device, num_threads, storage_cast, _consolidated): def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): root = False @@ -10754,17 +10756,6 @@ def split_storage(consolidated): # unspecified num_threads should mean 0 num_threads = 0 - _consolidated = {"storage": storage_cast} - if "metadata" in self._consolidated: - # faster than deepcopy - def copy_dict(d): - return { - k: v if not isinstance(v, dict) else copy_dict(v) - for k, v in d.items() - } - - _consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) - slice_map = split_storage(_consolidated) def view_as(src, dest): From 26affa02ebe82a072bd482338fdb262053dead24 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:40:17 +0100 Subject: [PATCH 31/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 
47ec626fe..3751f2d24 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -99,12 +99,12 @@ def test_to( tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") pin_mem = default_device.type == "cuda" if consolidated is True: - td = td.consolidate(pin_memory=pin_mem, set_on_tensor=True) + td = td.consolidate(pin_memory=pin_mem) if consolidated == "within": def to(td, num_threads): - return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to( + return td.consolidate(pin_memory=pin_mem).to( default_device, num_threads=num_threads ) @@ -127,12 +127,12 @@ def test_to_njt( tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") pin_mem = default_device.type == "cuda" if consolidated is True: - njt_td = njt_td.consolidate(pin_memory=pin_mem, set_on_tensor=True) + njt_td = njt_td.consolidate(pin_memory=pin_mem) if consolidated == "within": def to(td, num_threads): - return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to( + return td.consolidate(pin_memory=pin_mem).to( default_device, num_threads=num_threads ) From e0e1ddc9feaae8b52d550a536e81991d2557cd3e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 13:58:38 +0100 Subject: [PATCH 32/37] Update [ghstack-poisoned] --- tensordict/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 888daac21..554697cd9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3585,7 +3585,6 @@ def _reduce_get_metadata(self) -> dict: "is_locked": self._is_locked, } - @torch.compile(mode="reduce-overhead") def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): """Returns a nested dictionary of metadata, a flat Dict[NestedKey, Tensor] containing tensor data and a list of tensor sizes.""" if dtype is NO_DEFAULT: From 5d26eede6137d31bfe1159dbb682e9d5915217a5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 14:08:56 +0100 Subject: [PATCH 33/37] Update [ghstack-poisoned] --- tensordict/tensorclass.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index e7f5faade..8ef40304e 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -871,7 +871,6 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 f"Expected a TensorDictBase instance but got {type(tensordict)}" ) # Validating keys of tensordict - # tensordict = tensordict.copy() tensor_keys = tensordict.keys() # TODO: compile doesn't like set() over an arbitrary object if is_dynamo_compiling(): @@ -891,10 +890,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 exp_keys = set(cls.__expected_keys__) if non_tensordict is not None: nontensor_keys = set(non_tensordict.keys()) + total_keys = tensor_keys.union(nontensor_keys) else: nontensor_keys = set() non_tensordict = {} - total_keys = tensor_keys.union(nontensor_keys) + total_keys = tensor_keys for key in nontensor_keys: if key not in tensor_keys: continue @@ -922,7 +922,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 tc.__dict__["_non_tensordict"] = non_tensordict # since we aren't calling the dataclass init method, we need to manually check # whether a __post_init__ method has been defined and invoke it if so - if hasattr(tc, "__post_init__"): + if hasattr(cls, "__post_init__"): tc.__post_init__() return tc else: From ccb4dcc4deb8fb82b3cdcd655caa5d2227fe1336 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 
18 Oct 2024 14:14:52 +0100 Subject: [PATCH 34/37] Update [ghstack-poisoned] --- tensordict/tensorclass.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 8ef40304e..ae6643555 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -858,7 +858,7 @@ def get_parent_locals(cls, localns=localns): cls._type_hints = None -def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 +def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa: D417 """Tensor class wrapper to instantiate a new tensor class object. Args: @@ -866,7 +866,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 non_tensordict (dict): Dictionary with non-tensor and nested tensor class objects """ - if not isinstance(tensordict, TensorDictBase): + if safe and not isinstance(tensordict, TensorDictBase): raise RuntimeError( f"Expected a TensorDictBase instance but got {type(tensordict)}" ) @@ -918,8 +918,8 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 # empty tensordict and writing values to it. we can skip this because we already # have a tensordict to use as the underlying tensordict tc = cls.__new__(cls) - tc.__dict__["_tensordict"] = tensordict - tc.__dict__["_non_tensordict"] = non_tensordict + tc.__dict__.update({"_tensordict": tensordict, + "_non_tensordict": non_tensordict}) # since we aren't calling the dataclass init method, we need to manually check # whether a __post_init__ method has been defined and invoke it if so if hasattr(cls, "__post_init__"): @@ -1162,10 +1162,10 @@ def deliver_result(self, result, kwargs): else: non_tensordict = self._non_tensordict non_tensordict = dict(non_tensordict) - if copy_non_tensor: + if copy_non_tensor and non_tensordict: # use tree_map to copy non_tensordict = tree_map(lambda x: x, non_tensordict) - return self._from_tensordict(result, non_tensordict) + return self._from_tensordict(result, non_tensordict, safe=False) return result def wrapped_func(self, *args, **kwargs): From 1787e594af5da713ade832f17190424e2a22c1e1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 14:16:45 +0100 Subject: [PATCH 35/37] Update [ghstack-poisoned] --- tensordict/tensorclass.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index ae6643555..f59b239f4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1145,11 +1145,8 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): def check_out(kwargs, result): - out = kwargs.get("out") - if out is result: - # No need to transform output - return True - return False + # No need to transform output if True + return kwargs.get("out") is result def deliver_result(self, result, kwargs): if result is None: From f7f71aa32599e1615052963faf5a01308142f92a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 18 Oct 2024 14:24:28 +0100 Subject: [PATCH 36/37] Update [ghstack-poisoned] --- tensordict/tensorclass.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index f59b239f4..cacdd6905 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -570,7 +570,7 @@ def __torch_function__( setattr(cls, method_name, getattr(TensorDict, method_name)) for method_name in 
_FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): - setattr(cls, method_name, _wrap_td_method(method_name)) + setattr(cls, method_name, _wrap_td_method(method_name, force_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) @@ -1143,15 +1143,11 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 return wrapper -def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): - def check_out(kwargs, result): - # No need to transform output if True - return kwargs.get("out") is result - +def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False, force_wrap=False): def deliver_result(self, result, kwargs): if result is None: return - if isinstance(result, TensorDictBase) and not check_out(kwargs, result): + if (force_wrap or isinstance(result, TensorDictBase)) and kwargs.get("out") is not result: if not is_dynamo_compiling(): non_tensordict = super(type(self), self).__getattribute__( "_non_tensordict" From e8178cea5f17940451620ab0739ab95d07ee8b2f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Oct 2024 15:22:14 -0700 Subject: [PATCH 37/37] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 81 ++++- tensordict/_reductions.py | 2 +- tensordict/_td.py | 7 +- tensordict/base.py | 617 +++++++++++----------------------- tensordict/tensorclass.py | 13 +- tensordict/utils.py | 40 +-- test/test_tensordict.py | 10 +- 7 files changed, 305 insertions(+), 465 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 3751f2d24..38d18ac38 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import time from typing import Any import pytest @@ -25,17 +26,17 @@ class NJT: @classmethod def from_njt(cls, njt_tensor): - return NJT( + return cls( _values=njt_tensor._values, _offsets=njt_tensor._offsets, _lengths=njt_tensor._lengths, njt_shape=njt_tensor.size(0), - ) + ).clone() @pytest.fixture(autouse=True, scope="function") def empty_compiler_cache(): - torch._dynamo.reset_code_caches() + torch.compiler.reset() yield @@ -49,7 +50,8 @@ def _make_njt(): def _njt_td(): return TensorDict( - {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, + # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, + {str(i): _make_njt() for i in range(128)}, device="cpu", ) @@ -63,8 +65,9 @@ def njt_td(): def td(): njtd = _njt_td() for k0, v0 in njtd.items(): - for k1, v1 in v0.items(): - njtd[k0, k1] = NJT.from_njt(v1) + njtd[k0] = NJT.from_njt(v0) + # for k1, v1 in v0.items(): + # njtd[k0, k1] = NJT.from_njt(v1) return njtd @@ -78,6 +81,56 @@ def default_device(): pytest.skip("CUDA/MPS is not available") +@pytest.mark.parametrize( + "compile_mode,num_threads", + [ + [False, None], + # [False, 4], + # [False, 16], + ["default", None], + ["reduce-overhead", None], + ], +) +@pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" +) +class TestConsolidate: + def test_consolidate(self, benchmark, td, compile_mode, num_threads): + tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") + + def consolidate(td, num_threads): + return td.consolidate(num_threads=num_threads) + + if compile_mode: + consolidate = torch.compile( + consolidate, mode=compile_mode, dynamic=True, fullgraph=True + ) + + t0 = time.time() + consolidate(td, num_threads=num_threads) + elapsed = time.time() - t0 + tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec") + + for _ in range(3): + consolidate(td, num_threads=num_threads) + + benchmark(consolidate, td, num_threads) + + def test_to_njt(self, benchmark, njt_td, compile_mode, num_threads): + tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") + + def consolidate(td, num_threads): + return td.consolidate(num_threads=num_threads) + + if compile_mode: + consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True) + + for _ in range(3): + consolidate(njt_td, num_threads=num_threads) + + benchmark(consolidate, njt_td, num_threads) + + @pytest.mark.parametrize( "consolidated,compile_mode,num_threads", [ @@ -86,7 +139,7 @@ def default_device(): ["within", False, None], # [True, False, 4], # [True, False, 16], - # [True, "default", None], + [True, "default", None], ], ) @pytest.mark.skipif( @@ -114,7 +167,7 @@ def to(td, num_threads): return td.to(default_device, num_threads=num_threads) if compile_mode: - to = torch.compile(to, mode=compile_mode) + to = torch.compile(to, mode=compile_mode, dynamic=True) for _ in range(3): to(td, num_threads=num_threads) @@ -142,7 +195,7 @@ def to(td, num_threads): return td.to(default_device, num_threads=num_threads) if compile_mode: - to = torch.compile(to, mode=compile_mode) + to = torch.compile(to, mode=compile_mode, dynamic=True) for _ in range(3): to(njt_td, num_threads=num_threads) @@ -153,6 +206,14 @@ def to(td, num_threads): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main( - [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + [ + __file__, + "--capture", + "no", + "--exitfirst", + "--benchmark-group-by", + "func", + 
"-vvv", + ] + unknown ) diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 7234a42bd..c1a793f4e 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -138,7 +138,7 @@ def _make_td(cls, state): def _reduce_td(data: TensorDict): consolidated = getattr(data, "_consolidated", None) - if consolidated and consolidated["metadata"] is not None: + if isinstance(consolidated, dict): storage = consolidated["storage"] storge_metadata = consolidated["metadata"] return ( diff --git a/tensordict/_td.py b/tensordict/_td.py index 4387839b5..1252621d1 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -4210,7 +4210,7 @@ def _iter(): if self.leaves_only: for key in self._keys(): target_class = self.tensordict.entry_class(key) - if _is_tensor_collection(target_class): + if self.is_leaf(target_class): continue yield key else: @@ -4239,9 +4239,10 @@ def _iter_helper( # For lazy stacks value = value[0] cls = type(value) - is_leaf = self.is_leaf(cls) - if self.include_nested and not is_leaf: + is_tc = _is_tensor_collection(cls) + if self.include_nested and is_tc: yield from self._iter_helper(value, prefix=full_key) + is_leaf = self.is_leaf(cls) if not self.leaves_only or is_leaf: yield full_key diff --git a/tensordict/base.py b/tensordict/base.py index 554697cd9..17028d3b1 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -94,7 +94,6 @@ unravel_key, unravel_key_list, view_and_pad, - view_cat_split, view_old_as_new, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor @@ -3585,190 +3584,39 @@ def _reduce_get_metadata(self) -> dict: "is_locked": self._is_locked, } - def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): + def _reduce_vals_and_metadata(self, *, metadata): """Returns a nested dictionary of metadata, a flat Dict[NestedKey, Tensor] containing tensor data and a list of tensor sizes.""" - if dtype is NO_DEFAULT: - dtype = self.dtype - need_padding = dtype is None - # If the dtype is not unique (self.dtype is None) then we need the metadata - # because we need a custom unpickler - requires_metadata = requires_metadata | need_padding - - if requires_metadata: - # metadata is nested - metadata_dict = { - "cls": type(self).__name__, - "non_tensors": {}, - "leaves": {}, - "cls_metadata": self._reduce_get_metadata(), - } - else: - metadata_dict = None - - # flat_key_values is flat - flat_key_values = {} - - flat_size = [] - start = 0 - sorting_index = 0 - - def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): - nonlocal start, sorting_index - n = value.element_size() * value.numel() - if need_padding: - pad = n % 8 - if pad != 0: - pad = 8 - pad - else: - pad = 0 - flat_size.append(n + pad) - stop = start + flat_size[-1] - if requires_metadata: - metadata_dict["leaves"][key] = ( - _DTYPE2STRDTYPE[dtype], - list(shape), - # _DEVICE2STRDEVICE[device], - start, - stop, - pad, - flat_size[-1], - sorting_index, - ) - sorting_index = sorting_index + 1 - start = stop + if not metadata: + return None, list(self.items(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)) + + metadata_dict = { + "cls": type(self).__name__, + "non_tensors": {}, + "leaves": {}, + "nodes": {}, + "cls_metadata": self._reduce_get_metadata(), + } - def assign( - key, - value, - track_key=(), - metadata_dict=metadata_dict, - flat_size=flat_size, - ): - total_key = key if isinstance(key, tuple) else (key,) - total_key = track_key + total_key - cls = type(value) - if cls is Tensor or issubclass(cls, Tensor): - pass - # must go 
before is_tensor_collection - elif _is_non_tensor(cls): - if requires_metadata: - metadata_dict["non_tensors"][key] = ( - value.data, - list(value.batch_size), - ) - return - elif _is_tensor_collection(cls): - metadata_dict_key = None - if requires_metadata: - metadata_dict_key = metadata_dict[key] = { - "cls": cls.__name__, - "non_tensors": {}, - "leaves": {}, - "cls_metadata": value._reduce_get_metadata(), - } - local_assign = lambda key, value: assign( - key, - value, - track_key=total_key, - metadata_dict=metadata_dict_key, - flat_size=flat_size, - ) - value._fast_apply( - local_assign, - named=True, - nested_keys=True, - call_on_nested=True, - is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - filter_empty=True, + for k, it in self.items(True, False, is_leaf=_NESTED_TENSORS_AS_LISTS): + if _is_non_tensor(type(it)): + metadata_dict["non_tensors"][k] = ( + it.data, + list(it.batch_size), ) - return - # Tensors: DTensor, nested and then regular - if hasattr(value, "full_tensor"): - raise NotImplementedError("DTensor is not supported yet") - if getattr(value, "is_nested", False): - if value.layout is torch.jagged: - # Get the values - values = value._values - shape = [v if isinstance(v, int) else -1 for v in values.shape] - # Get the offsets - offsets = value._offsets - # Get the lengths - lengths = value._lengths - - # Now we're saving the two tensors - # We will rely on the fact that the writing order is preserved in python dict - # (since python 3.7). Later, we will read the NJT then the NJT offset in that order - # to do the allocation. - flat_key_values[_prefix_last_key(total_key, "")] = value - flat_size.append(0) - flat_key_values[_prefix_last_key(total_key, "")] = ( - values - ) - add_single_value( - values, - _prefix_last_key(key, ""), - metadata_dict, - values.dtype, - shape, - flat_size, - ) - # Lengths - if lengths is not None: - flat_key_values[ - _prefix_last_key(total_key, "") - ] = lengths - add_single_value( - lengths, - _prefix_last_key(key, ""), - metadata_dict, - lengths.dtype, - lengths.shape, - flat_size, - ) - # Offsets - flat_key_values[_prefix_last_key(total_key, "")] = ( - offsets - ) - add_single_value( - offsets, - _prefix_last_key(key, ""), - metadata_dict, - offsets.dtype, - offsets.shape, - flat_size, - ) - - else: - raise NotImplementedError( - "NST is not supported, please use layout=torch.jagged when building the nested tensor." 
- ) - return - flat_key_values[total_key] = value - add_single_value( - value, - key, - metadata_dict, - value.dtype, - value.shape, - # value.device, - flat_size, - ) - - self._fast_apply( - assign, - named=True, - call_on_nested=True, - nested_keys=True, - is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - filter_empty=True, - ) - return metadata_dict, flat_key_values, flat_size, need_padding + elif _is_tensor_collection(type(it)): + metadata_dict["nodes"][k] = { + "cls": type(it).__name__, + "cls_metadata": it._reduce_get_metadata(), + } + else: + metadata_dict["leaves"][k] = it + return metadata_dict, None def consolidate( self, filename: Path | str | None = None, *, - num_threads=0, + num_threads: int | None = None, device: torch.device | None = None, non_blocking: bool = False, inplace: bool = False, @@ -3839,15 +3687,83 @@ def consolidate( if self.is_consolidated(): return self - ( - metadata_dict, - flat_dict, - flat_size, - need_padding, - ) = self._reduce_vals_and_metadata( - requires_metadata=filename is not None or metadata, dtype=None - ) - filesize = sum(flat_size) + metadata = metadata or filename + metadata_dict, items = self._reduce_vals_and_metadata(metadata=metadata) + + start = 0 + lengths = [] + swaps = [] + origs = [] + + def view_and_pad(key, tensor: torch.Tensor, lengths=lengths) -> torch.Tensor: + nonlocal start + if hasattr(tensor, "full_tensor"): + raise NotImplementedError("DTensor is not supported yet") + if getattr(tensor, "is_nested", False): + if tensor.layout is torch.jagged: + # Get the values + values = tensor._values + shape = [v if isinstance(v, int) else -1 for v in values.shape] + # Get the offsets + offsets = tensor._offsets + # Get the lengths + lengths = tensor._lengths + + # Now we're saving the two tensors + # We will rely on the fact that the writing order is preserved in python dict + # (since python 3.7). Later, we will read the NJT then the NJT offset in that order + # to do the allocation. + origs.append(tensor) + swaps.append(None) + + view_and_pad(_prefix_last_key(key, ""), values) + # Lengths + if lengths is not None: + view_and_pad( + _prefix_last_key(key, ""), lengths + ) + # Offsets + view_and_pad(_prefix_last_key(key, ""), offsets) + else: + raise NotImplementedError( + "Strided nested-tensors are not supported yet." + ) + if is_dynamo_compiling(): + # We should maybe clone by default but that seems a bit too harsh? 
+ tensor = tensor.clone(memory_format=torch.contiguous_format) + else: + stride = tensor.stride() + if (stride and stride[-1] != 1) or tensor.storage_offset(): + tensor = tensor.clone(memory_format=torch.contiguous_format) + + origs.append(tensor) + swap = tensor.view(-1).view(torch.uint8) + # result must always have a multiple of 8 elements + pad = swap.numel() % 8 + if pad != 0: + swap = torch.cat([swap, swap.new_zeros(8 - pad)]) + n = swap.numel() + if metadata: + info = ( + _DTYPE2STRDTYPE[tensor.dtype], + list(tensor.shape), + start, + pad, + n, + ) + metadata_dict["leaves"][key] = info + start = start + n + lengths.append(n) + swaps.append(swap) + + if metadata: + for key, val in metadata_dict: + view_and_pad(key, val) + else: + for key, val in items: + view_and_pad(key, val) + + filesize = start device = torch.device(device) if device is not None else None if filename is None: storage = torch.empty( @@ -3894,205 +3810,64 @@ def consolidate( total_storage[-8:] = len_metadata total_storage[-8 - metadata_dict_json.numel() : -8] = metadata_dict_json storage = total_storage[:-suffix] - # assert len(storage.untyped_storage()) == filesize - - offsets = torch.tensor([0] + flat_size).cumsum(0).tolist() + if num_threads is None: + num_threads = 0 if num_threads > 0: - - def assign( - *, - k, - v, - start, - stop, - njts, - storage=storage, - non_blocking=non_blocking, - ): - """Reads a slice of the storage and assigns the resulting tensor in flat_dict.""" - # v may need padding - if k[-1].startswith(""): - njts[k] = v - return - v_pad = v.view(-1).view(torch.uint8) - exp_length = stop - start - pad = exp_length - v_pad.numel() - if pad: - v_pad = torch.cat([v_pad, v_pad.new_zeros(pad)]) - storage_slice = storage[start:stop] - storage_slice.copy_(v_pad, non_blocking=non_blocking) - - shape, dtype = v.shape, v.dtype - new_v = storage_slice.view(dtype) - if pad: - new_v = new_v[: v.numel()] - new_v = new_v.view(shape) - if set_on_tensor: - v.set_( - new_v.untyped_storage(), - storage_offset=new_v.storage_offset(), - stride=new_v.stride(), - size=new_v.size(), - ) - return - flat_dict[k] = new_v - - njts = {} - if num_threads > 1: - executor = ThreadPoolExecutor(num_threads) - r = [] - for i, (k, v) in enumerate(flat_dict.items()): - r.append( - executor.submit( - assign, - k=k, - v=v, - start=offsets[i], - stop=offsets[i + 1], - njts=njts, - ) - ) - if not return_early: - wait(r) - else: - # TODO: We'd need to merge the second half of this function to make this a thing - raise NotImplementedError( - "return_early is not implemented yet for `consolidate`." 
- ) - else: - for i, (k, v) in enumerate(flat_dict.items()): - assign( - k=k, - v=v, - start=offsets[i], - stop=offsets[i + 1], - njts=njts, - ) - if not set_on_tensor: - for njt_key, njt in njts.items(): - newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) - njt_key_values = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_offset = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_lengths = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - val = _rebuild_njt_from_njt( - njt, - values=flat_dict.pop(njt_key_values), - offsets=flat_dict.pop(njt_key_offset), - lengths=flat_dict.pop(njt_key_lengths, None), - ) - del flat_dict[njt_key] - flat_dict[newkey] = val - + raise NotImplementedError + else: if non_blocking and device.type != "cuda": # sync if needed - self._sync_all() - if set_on_tensor: - return self - else: + td._sync_all() + torch.cat(swaps, out=storage) + swaps = storage.split(lengths) + + result = [ + view_old_as_new( + v, + oldv, + # set_on_tensor=set_on_tensor) + ) + for (v, oldv) in zip(swaps, origs, strict=True) + ] - items = [] - for v in flat_dict.values(): - if v.is_nested: - items.append(None) - continue - if v.device != storage.device: - v = v.to(storage.device, non_blocking=non_blocking) - if is_dynamo_compiling(): - v = v.clone(memory_format=torch.contiguous_format) - else: - stride = v.stride() - if (stride and stride[-1] != 1) or v.storage_offset(): - v = v.clone(memory_format=torch.contiguous_format) - items.append(v) - - items = view_cat_split( - self, - items, - storage, - need_padding, - non_blocking, - device, - flat_size, - set_on_tensor, - ) if set_on_tensor: return self - for k, v in _zip_strict(list(flat_dict.keys()), items): - if not k[-1].startswith("<"): - flat_dict[k] = v - elif k[-1].startswith(""): - # NJT/NT always comes before offsets/shapes - nt = flat_dict[k] - assert not v.numel() - nt_lengths = None - del flat_dict[k] - elif k[-1].startswith(""): - nt_vaues = v - del flat_dict[k] - elif k[-1].startswith(""): - nt_lengths = v - del flat_dict[k] - elif k[-1].startswith(""): - newk = k[:-1] + (k[-1].replace("", ""),) - nt_offsets = v - del flat_dict[k] - - val = _rebuild_njt_from_njt( - nt, values=nt_vaues, offsets=nt_offsets, lengths=nt_lengths - ) - - flat_dict[newk] = val - - # delete the nested value to make sure that if there was an - # ordering mismatch we wouldn't be looking at the value key of - # another nested tensor. 
- del nt, nt_vaues, nt_offsets, nt_lengths - else: - flat_dict[k] = v - - def assign_val(key, val): - if isinstance(key, str): - key = (key,) - return flat_dict.get(key, val) - - if filename is None: - device = self.device - elif not inplace: - device = torch.device("cpu") - elif self.device is not None and self.device != torch.device("cpu"): - self.clear_device_() - device = None - else: - device = None - if inplace: - result = self + if filename is None: + device = self.device + elif not inplace: + device = torch.device("cpu") + elif self.device is not None and self.device != torch.device("cpu"): + self.clear_device_() + device = None + else: + device = None + if inplace: + out = self + elif device in (self.device, None): + out = self.copy() + else: + out = self._fast_apply(lambda x: x, device=device) + if metadata: + keys = metadata_dict["leaves"].keys() + else: + keys, _ = zip(*items) + for k, v in _zip_strict(keys, result): + if isinstance(k, str): + k = (k,) + out._set_tuple(k, v, validated=True, inplace=False) + if metadata: + out._consolidated = {"storage": storage} + out._consolidated["metadata"] = metadata_dict_or_values else: - result = None - result = self._fast_apply( - assign_val, - named=True, - nested_keys=True, - is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - out=result, - device=device, - ) - result._consolidated = {"storage": storage, "metadata": metadata_dict} + out._consolidated = True + if filename is not None: if use_buffer: with open(filename, "w+b") as f: f.write(total_storage._handler.buffer) - # with open(Path(filename).with_suffix(".json"), "wb") as f: - # metadata_dict["size"] = filesize - # f.write(json.dumps(metadata_dict)) - return result + return out @classmethod def from_consolidated(cls, filename): @@ -4117,7 +3892,20 @@ def from_consolidated(cls, filename): def is_consolidated(self): """Checks if a TensorDict has a consolidated storage.""" - return hasattr(self, "_consolidated") + return getattr(self, "_consolidated", False) + + def consolidated_storage(self): + consolidated_data = getattr(self, "_consolidated", False) + if isinstance(consolidated_data, dict): + return consolidated_data["storage"] + elif consolidated_data: + for k, t in self.items(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS): + break + storage = t.untyped_storage() + return torch.empty((), dtype=torch.uint8, device=self.device).set_( + storage, storage_offset=0, stride=(1,), size=(len(storage),) + ) + return None def memmap_( self, @@ -5594,53 +5382,38 @@ def items( if is_leaf is None: is_leaf = _default_is_leaf - def _items(): - if include_nested and leaves_only: - # check the conditions once only - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - else: - yield k, val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield k, val - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - elif leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if is_leaf(type(val)): - yield k, val - else: - for k in self.keys(): - yield k, self._get_str(k, NO_DEFAULT) - if sort: yield from sorted( - _items(), + self.items(include_nested, leaves_only, is_leaf), key=lambda item: ( item[0] 
if isinstance(item[0], str) else ".".join(item[0]) ), ) + + if include_nested: + # check the conditions once only + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + cls = type(val) + if not leaves_only or is_leaf(cls): + yield k, val + if _is_tensor_collection(cls): + yield from ( + (_unravel_key_to_tuple((k, _key)), _val) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) + ) + elif leaves_only: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if is_leaf(type(val)): + yield k, val else: - yield from _items() + for k in self.keys(): + yield k, self._get_str(k, NO_DEFAULT) def non_tensor_items(self, include_nested: bool = False): """Returns all non-tensor leaves, maybe recursively.""" @@ -10631,11 +10404,17 @@ def copy_dict(d): if compilable: result = self._to_consolidated_compile( - device=device, num_threads=num_threads, storage_cast=storage_cast, _consolidated=_consolidated, + device=device, + num_threads=num_threads, + storage_cast=storage_cast, + _consolidated=_consolidated, ) else: result = self._to_consolidated_eager( - device=device, num_threads=num_threads, storage_cast=storage_cast, _consolidated=_consolidated, + device=device, + num_threads=num_threads, + storage_cast=storage_cast, + _consolidated=_consolidated, ) if non_blocking in (False, None): @@ -10652,7 +10431,9 @@ def copy_dict(d): return result - def _to_consolidated_eager(self, *, device, num_threads, storage_cast, _consolidated): + def _to_consolidated_eager( + self, *, device, num_threads, storage_cast, _consolidated + ): untyped_storage = storage_cast.untyped_storage() @@ -10715,7 +10496,9 @@ def set_(x): result._consolidated = _consolidated return result - def _to_consolidated_compile(self, *, device, num_threads, storage_cast, _consolidated): + def _to_consolidated_compile( + self, *, device, num_threads, storage_cast, _consolidated + ): def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): root = False diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index cacdd6905..124503274 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -918,8 +918,9 @@ def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa: # empty tensordict and writing values to it. 
we can skip this because we already # have a tensordict to use as the underlying tensordict tc = cls.__new__(cls) - tc.__dict__.update({"_tensordict": tensordict, - "_non_tensordict": non_tensordict}) + tc.__dict__.update( + {"_tensordict": tensordict, "_non_tensordict": non_tensordict} + ) # since we aren't calling the dataclass init method, we need to manually check # whether a __post_init__ method has been defined and invoke it if so if hasattr(cls, "__post_init__"): @@ -1143,11 +1144,15 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 return wrapper -def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False, force_wrap=False): +def _wrap_td_method( + funcname, *, copy_non_tensor=False, no_wrap=False, force_wrap=False +): def deliver_result(self, result, kwargs): if result is None: return - if (force_wrap or isinstance(result, TensorDictBase)) and kwargs.get("out") is not result: + if (force_wrap or isinstance(result, TensorDictBase)) and kwargs.get( + "out" + ) is not result: if not is_dynamo_compiling(): non_tensordict = super(type(self), self).__getattribute__( "_non_tensordict" diff --git a/tensordict/utils.py b/tensordict/utils.py index 1cee55a2f..e0ec809ae 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2721,38 +2721,20 @@ def view_and_pad(tensor: torch.Tensor, need_padding: bool) -> torch.Tensor: return result -def view_old_as_new(v: torch.Tensor, oldv: torch.Tensor) -> torch.Tensor: +def view_old_as_new( + v: torch.Tensor, oldv: torch.Tensor, set_on_tensor=False +) -> torch.Tensor: + if set_on_tensor: + oldv.set_( + v.untyped_storage(), + storage_offset=v.storage_offset(), + stride=v.stride(), + size=oldv.size(), + ) + return oldv if oldv is None: return v v = v.view(oldv.dtype) if v.numel() > oldv.numel(): return v[: oldv.numel()].view(oldv.shape) return v.view(oldv.shape) - - -@torch.compiler.disable() -def view_cat_split( - td, items, storage, need_padding, non_blocking, device, flat_size, set_on_tensor -): - items_flat = [view_and_pad(v, need_padding) for v in items if v is not None] - if non_blocking and device.type != "cuda": - # sync if needed - td._sync_all() - torch.cat(items_flat, out=storage) - # TODO: breaks with NJT - result = [ - view_old_as_new(v, oldv) - for (v, oldv) in zip(storage.split(flat_size), items, strict=True) - ] - if set_on_tensor: - for t_dest, t_src in zip(result, items): - if t_src is None: - # njt is decomposed - continue - t_src.set_( - t_dest.untyped_storage(), - storage_offset=t_dest.storage_offset(), - stride=t_dest.stride(), - size=t_dest.size(), - ) - return result diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 0f1f65b5d..1dbcbcea1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -430,7 +430,15 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty ), td_c.to_dict() assert td_c["d"] == "a string!" - storage = td_c._consolidated["storage"] + storage = td_c.consolidated_storage() + print( + storage.untyped_storage().data_ptr(), + td_c["b", "c"].untyped_storage().data_ptr(), + ) + print( + storage.untyped_storage().data_ptr(), td_c["a"].untyped_storage().data_ptr() + ) + assert isinstance(storage, torch.Tensor) storage *= 0 if not nested: assert (td.to(td_c.device) != td_c).any(), td_c.to_dict()
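
Usage sketch for reviewers (not part of the diff): the snippet below shows how the consolidation path touched by this stack is exercised end to end. It is a minimal sketch assuming the API as it stands after the last patch — in particular the consolidated_storage() helper and the num_threads keyword shown in the diffs above — and the CUDA transfer at the end is illustrative only, not a benchmark from this PR.

# Hypothetical end-to-end example; assumes tensordict built from this stack.
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(3, 4), "b": {"c": torch.arange(3)}},
    batch_size=[3],
)

# consolidate() copies every leaf into one flat uint8 buffer so that a later
# device transfer is a single copy rather than one copy per leaf.
td_c = td.consolidate(num_threads=0)
assert td_c.is_consolidated()

# Shared buffer backing all leaves (helper introduced in the last patch).
storage = td_c.consolidated_storage()
assert isinstance(storage, torch.Tensor)

# to() on a consolidated tensordict casts the single buffer and re-views
# every leaf inside it on the target device.
if torch.cuda.is_available():
    td_cuda = td_c.to("cuda", non_blocking=True)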