diff --git a/znh5md/misc.py b/znh5md/misc.py index fd569b7d..f8e54524 100644 --- a/znh5md/misc.py +++ b/znh5md/misc.py @@ -24,21 +24,19 @@ def concatenate_varying_shape_arrays( """ # Determine the maximum shape along all dimensions - maxshape = list(values[0].shape) - for value in values[1:]: - maxshape = [max(a, b) for a, b in zip(maxshape, value.shape)] + shapes = np.array([value.shape for value in values]) + maxshape = tuple(np.max(shapes, axis=0)) # Add the batch dimension maxshape = (len(values), *maxshape) - # Create an array filled with the fillvalue + # Create the array filled with the fillvalue dataset = np.full(maxshape, fillvalue, dtype=dtype) - - # Insert each value into the dataset + + # Get the slices for each value and assign them all at once for i, value in enumerate(values): - # Create slices for each dimension of the current value - slices = tuple(slice(0, dim) for dim in value.shape) - dataset[(i,) + slices] = value + slices = (i,) + tuple(slice(0, dim) for dim in value.shape) + dataset[slices] = value return dataset