@@ -305,10 +305,18 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`
                 # - will achieve zero-copy if the tensor is contiguous
-                return batch[0].unsqueeze(0).contiguous()
+                # Replace original tensor so that its memory can be freed
+                # in the non-contiguous case.
+                batch[0] = batch[0].contiguous()
+                return batch[0].unsqueeze(0)
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
-                return torch.stack(batch)
+                stack = torch.stack(batch)
+                # Replace original tensors with slices into the new one,
+                # so that their memory can be freed.
+                for i in range(len(batch)):
+                    batch[i] = stack[i]
+                return stack
 
         return batch
 
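The idea in this hunk is that indexing into the stacked result yields zero-copy views, so once each `batch[i]` is swapped for such a view, the per-item storages lose their last references and can be freed. Likewise, `Tensor.contiguous()` returns `self` unchanged when the tensor is already contiguous, so the single-tensor path only copies (and only then benefits from the reassignment) in the non-contiguous case. A minimal standalone sketch in plain PyTorch, assuming `batch` holds the only references to its tensors (names here are illustrative, not vLLM's API):

```python
import torch

# Sketch of the stack-and-alias trick: after stacking, every entry of
# `batch` is replaced by a view into the stacked buffer, so the original
# per-item storages become unreferenced and can be garbage-collected.
batch = [torch.randn(4, 8) for _ in range(3)]

stack = torch.stack(batch)   # one copy into a single (3, 4, 8) buffer
for i in range(len(batch)):
    batch[i] = stack[i]      # zero-copy view, shares storage with `stack`

# Every entry now aliases the stacked buffer instead of its own storage
# (untyped_storage() requires PyTorch >= 2.0).
assert all(
    t.untyped_storage().data_ptr() == stack.untyped_storage().data_ptr()
    for t in batch
)
```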
@@ -337,10 +345,21 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`
                 # - will achieve zero-copy if the tensor is contiguous
-                return batch[0].contiguous()
-            first_shape = batch[0].shape
-            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
-                return torch.concat(batch)
+                # Replace original tensor so that its memory can be freed
+                # in the non-contiguous case.
+                batch[0] = batch[0].contiguous()
+                return batch[0]
+            first_shape = batch[0].shape[1:]
+            if all(elem.shape[1:] == first_shape for elem in batch):
+                concat = torch.concat(batch)
+                # Replace original tensors with slices into the new one,
+                # so that their memory can be freed.
+                off = 0
+                for i in range(len(batch)):
+                    size = batch[i].shape[0]
+                    batch[i] = concat[off:off + size]
+                    off += size
+                return concat
 
         return [e for elem in batch for e in elem]
 
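The concat path performs the same aliasing, but with a running row offset, since entries may differ in their first dimension. A self-contained sketch of that bookkeeping (illustrative shapes and names, not taken from the PR):

```python
import torch

# Sketch of the offset bookkeeping: concatenate along dim 0, then
# replace each original tensor with the slice of `concat` that holds
# its rows. Slices are zero-copy views, so the originals' separate
# storages can be garbage-collected.
batch = [torch.randn(n, 8) for n in (2, 5, 3)]

concat = torch.concat(batch)           # shape (10, 8), one copy
off = 0
for i in range(len(batch)):
    size = batch[i].shape[0]           # rows contributed by this entry
    batch[i] = concat[off:off + size]  # zero-copy view into `concat`
    off += size

assert off == concat.shape[0]
assert torch.equal(batch[1], concat[2:7])  # second entry had 5 rows
```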