
xr.concat has over 3x the peak memory usage of np.concatenate and is 5x slower, even with large chunk sizes #10864

@mjwillson

What is your issue?

I guess this isn't technically a bug, but the performance is quite poor. I'd be curious to understand why, and whether there's any workaround, especially for the peak memory, which is causing OOMs for us in some jobs.

Here I am concatenating very simple DataArrays with no coordinates and large (100MiB) chunks. I understand xarray will have some overhead relative to numpy, but I would expect this overhead to be largely amortised away when dealing with chunks of data this large.

xr.concat has peak memory usage of around 6x the size of the inputs, and this seems fairly consistent over a range of chunk sizes and numbers of chunks.

import xarray as xr
import numpy as np
import tracemalloc
import gc

tracemalloc.start()

# 10 x 100MiB arrays
data_arrays = [
    xr.DataArray(data=np.zeros(25 * 1024*1024, dtype=np.float32), dims=['x'])
    for _ in range(10)
]

print("After allocating input arrays:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")
print(f"  Peak memory usage: {peak/1024/1024:g} MiB.")
tracemalloc.reset_peak()

print("Timing xr.concat:")
%time result = xr.concat(data_arrays, dim='x')

print("After xr.concat:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")
print(f"  Peak memory usage: {peak/1024/1024:g} MiB.")

del result
del data_arrays

print("After dropping references to input and result:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")

gc.collect()
print("After gc.collect():")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")

tracemalloc.stop()

Output:

After allocating input arrays:
  Current memory usage: 1000.03 MiB.
  Peak memory usage: 1000.05 MiB.
Timing xr.concat:
CPU times: user 1.02 s, sys: 1.23 s, total: 2.26 s
Wall time: 2.23 s
After xr.concat:
  Current memory usage: 6250.08 MiB.
  Peak memory usage: 6250.1 MiB.
After dropping references to input and result:
  Current memory usage: 6250.08 MiB.
After gc.collect():
  Current memory usage: 0.0356817 MiB.
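The snippets in this issue rely on IPython's %time magic. Outside IPython, the same measurement can be wrapped in a small helper. The sketch below uses a hypothetical `measure_call` function (not part of xarray or numpy) that reports wall time via `time.perf_counter` and the tracemalloc peak for a single call.

```python
import time
import tracemalloc


def measure_call(fn, *args, **kwargs):
    """Hypothetical helper: run fn once and return (result, wall_seconds, peak_mib).

    If tracing is already active (as in the snippets in this issue), the peak
    includes memory allocated earlier, such as the inputs, and the outer peak
    is reset. If this helper starts tracing itself, only allocations made
    during the call are counted.
    """
    started_here = not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start()
    tracemalloc.reset_peak()  # requires Python 3.9+
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - t0
    _, peak = tracemalloc.get_traced_memory()
    if started_here:
        tracemalloc.stop()
    return result, elapsed, peak / 1024 / 1024
```

For example, `result, seconds, peak_mib = measure_call(xr.concat, data_arrays, dim='x')` would replace the `%time` line in the first snippet.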

Here is the equivalent for np.concatenate. Its peak memory usage is 2x the size of the inputs, which is what I would expect (the inputs plus the output buffer), and it is also a lot faster:

tracemalloc.start()

print("At start:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")

# 10 x 100MiB arrays
arrays = [
    np.zeros(25 * 1024*1024, dtype=np.float32)
    for _ in range(10)
]

print("After allocating input arrays:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")
tracemalloc.reset_peak()

print("Timing np.concatenate:")
%time result = np.concatenate(arrays, axis=0)

print("After np.concatenate:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")
print(f"  Peak memory usage: {peak/1024/1024:g} MiB.")

del result
del arrays

print("After dropping references to input and result:")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")

gc.collect()
print("After gc.collect():")
current, peak = tracemalloc.get_traced_memory()
print(f"  Current memory usage: {current/1024/1024:g} MiB.")

tracemalloc.stop()

Output:

At start:
  Current memory usage: 0.00126362 MiB.
After allocating input arrays:
  Current memory usage: 1000.02 MiB.
Timing np.concatenate:
CPU times: user 231 ms, sys: 254 ms, total: 485 ms
Wall time: 472 ms
After np.concatenate:
  Current memory usage: 2000.03 MiB.
  Peak memory usage: 2000.07 MiB.
After dropping references to input and result:
  Current memory usage: 1000.04 MiB.
After gc.collect():
  Current memory usage: 0.0444469 MiB.
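As a quick check on the figures in the title, the ratios can be computed directly from the two outputs. All numbers below are copied from those outputs; the expected np.concatenate peak is taken to be the inputs plus one output buffer of the same size.

```python
# Figures copied from the two outputs in this issue.
input_mib = 10 * 25 * 1024 * 1024 * 4 / 1024 / 1024  # 10 arrays x 25Mi float32 elements x 4 bytes
xr_peak_mib, np_peak_mib = 6250.1, 2000.07
xr_cpu_s, np_cpu_s = 2.26, 0.485

expected_np_peak = 2 * input_mib  # inputs + one output buffer

print(f"input size: {input_mib:g} MiB")                                          # 1000 MiB
print(f"np.concatenate peak / expected: {np_peak_mib / expected_np_peak:.2f}")   # 1.00
print(f"xr.concat peak / inputs: {xr_peak_mib / input_mib:.2f}x")                # 6.25x
print(f"xr.concat peak / np peak: {xr_peak_mib / np_peak_mib:.2f}x")             # 3.12x
print(f"xr.concat CPU / np CPU: {xr_cpu_s / np_cpu_s:.2f}x")                     # 4.66x
```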

I suspect the high peak memory may be due in part to reference cycles, since the memory sticks around even after the references to the inputs and the result have been dropped, but disappears after a gc.collect(). It's possible that under memory pressure a GC would be triggered sooner, reducing the peak usage, but this is hard to verify, and we have seen OOMs from xr.concat in practice. I'm also not sure why any large intermediate results would be needed in the main computation here, beyond passing the data into np.concatenate.
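To illustrate the suspected mechanism in isolation, the sketch below (plain numpy and the standard `gc` module, no xarray) shows the pattern the outputs suggest: a large buffer reachable only from a reference cycle stays counted by tracemalloc after `del`, and is released only when the cyclic collector runs. The `Node` class is purely illustrative; this is not a claim about what xarray does internally.

```python
import gc
import tracemalloc

import numpy as np


class Node:
    """Illustrative object that can take part in a reference cycle."""

    def __init__(self, payload):
        self.payload = payload
        self.other = None


def traced_mib():
    return tracemalloc.get_traced_memory()[0] / 1024 / 1024


gc.disable()  # keep the automatic collector from running mid-demo
tracemalloc.start()

a = Node(np.zeros(10 * 1024 * 1024, dtype=np.uint8))  # 10 MiB buffer
b = Node(None)
a.other, b.other = b, a  # a <-> b reference cycle
del a, b  # refcounts never reach zero, so nothing is freed yet

after_del = traced_mib()
unreachable = gc.collect()  # the cyclic collector finds and frees the pair
after_collect = traced_mib()

tracemalloc.stop()
gc.enable()

print(f"after del: {after_del:.1f} MiB, after gc.collect(): {after_collect:.1f} MiB "
      f"({unreachable} unreachable objects found)")
```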

As for the slowdown (2.26 s vs. 485 ms of total CPU time here), I really have no idea, especially given that there are no coordinates or indexes on these DataArrays that would require any special treatment, and the chunks are large enough that I'd expect most of the compute to be done by numpy.
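Because these DataArrays carry no coordinates or indexes, one possible stopgap is to concatenate the underlying numpy buffers directly and rewrap the result. This is a sketch under that assumption, not an xarray API and not a diagnosis of where xr.concat spends its time. `concat_without_coords` is a hypothetical name; it assumes every input is numpy-backed (not dask), has no coordinates, and shares the same dims in the same order.

```python
import numpy as np
import xarray as xr


def concat_without_coords(data_arrays, dim):
    """Hypothetical stopgap: concatenate coordinate-free DataArrays via np.concatenate.

    Assumes every input is numpy-backed, has no coordinates, and has the same
    dims in the same order; attrs are copied from the first input.
    """
    first = data_arrays[0]
    for da in data_arrays:
        if len(da.coords) > 0:
            raise ValueError("this shortcut only handles DataArrays without coordinates")
        if da.dims != first.dims:
            raise ValueError("all inputs must have the same dims in the same order")
    axis = first.get_axis_num(dim)
    # .data is the underlying numpy array, so nothing is copied before np.concatenate.
    out = np.concatenate([da.data for da in data_arrays], axis=axis)
    return xr.DataArray(out, dims=first.dims, attrs=dict(first.attrs))
```

In the first snippet this would be `result = concat_without_coords(data_arrays, dim='x')`. Its peak memory should then be close to the np.concatenate figure, though that has not been measured here.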

Note that for some dtypes the overhead is even worse: for np.int8, I found that xr.concat's peak memory usage was 19x the size of the inputs, and it was 15-20x slower than numpy. Perhaps some unnecessary dtype conversions are happening there.
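To see how the peak scales with dtype, the peak-to-input ratio can be swept across dtypes for any concatenation function. The sketch below uses a hypothetical `peak_to_input_ratio` helper and much smaller arrays than the snippets above (1 MiB each, chosen only to keep it quick). It also returns the result dtype, so a silent dtype change would show up. With np.concatenate the ratio should stay near 2 (inputs plus output) for every dtype.

```python
import tracemalloc

import numpy as np


def peak_to_input_ratio(concat_fn, dtype, n_arrays=10, nbytes_each=1024 * 1024):
    """Hypothetical helper: peak traced memory during concat_fn(arrays) / total input bytes.

    Inputs are allocated while tracing, so the peak includes them, matching the
    measurement in the snippets above. Returns (ratio, result_dtype).
    """
    itemsize = np.dtype(dtype).itemsize
    tracemalloc.start()
    try:
        arrays = [np.zeros(nbytes_each // itemsize, dtype=dtype) for _ in range(n_arrays)]
        input_bytes = sum(a.nbytes for a in arrays)
        tracemalloc.reset_peak()  # requires Python 3.9+
        result = concat_fn(arrays)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak / input_bytes, result.dtype


for dtype in (np.int8, np.int16, np.float32, np.float64):
    ratio, out_dtype = peak_to_input_ratio(lambda arrs: np.concatenate(arrs, axis=0), dtype)
    print(f"{np.dtype(dtype).name}: peak/input = {ratio:.2f}, result dtype = {out_dtype}")
```

Passing a function such as `lambda arrs: xr.concat([xr.DataArray(a, dims=['x']) for a in arrs], dim='x')` instead would give the corresponding xarray figures for comparison.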
