Algorithm for multi-stage rechunking #89

Merged · 13 commits merged into pangeo-data:master on Nov 22, 2022

Conversation

@shoyer (Collaborator) commented on May 23, 2021

Background

Rechunker's current "Push / Pull Consolidated" algorithm can be thought of as a mix of "split" and "combine" steps:

  • Combine source_chunks -> read_chunks (to fill up max_mem per chunk)
  • Split/Combine if read_chunks != write_chunks:
    • Split read_chunks -> int_chunks
    • Combine int_chunks -> write_chunks
  • Split write_chunks -> target_chunks

This is pretty clever, but can run into scalability issues. In particular, sometimes int_chunks must be very small, which results in a significant overhead for reading/writing small files (#80).
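To make the three steps above concrete, here is a minimal sketch of calling the existing single-stage planner on the ERA5 example discussed below. I'm assuming rechunker.algorithm.rechunking_plan accepts (shape, source_chunks, target_chunks, itemsize, max_mem), mirroring how the example script further down calls the algorithm module:

from rechunker import algorithm

shape = (350640, 721, 1440)
source_chunks = (31, 721, 1440)
target_chunks = (350640, 10, 10)

# read_chunks consolidates source chunks up to max_mem, write_chunks
# consolidates toward the target, and int_chunks is the chunking shared by
# both -- the quantity that can become tiny, which is the problem this PR
# addresses.
read_chunks, int_chunks, write_chunks = algorithm.rechunking_plan(
    shape, source_chunks, target_chunks, itemsize=4, max_mem=int(500e6)
)
print(read_chunks, int_chunks, write_chunks)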

Proposal

I think the right way to fix this is to extend Rechunker's algorithm to allow for multiple split/combine stages -- as many as necessary to avoid creating tiny intermediate chunk sizes. This PR implements the math for such an algorithm, in a fully backwards compatible fashion. Users can control the number of stages via the new min_mem parameter, which specifies a minimum chunk size in bytes.

Multi-stage rechunking is not yet hooked up to any of Rechunker's executors. I'm proposing adding it to Rechunker because it reuses/modifies the existing code, and I need the math for rechunking inside Xarray-Beam (this was an easier way to explore writing the Beam executor part). Unfortunately it isn't easy to add into Rechunker's existing model, which stores intermediate results in Zarr, because multi-stage rechunking (at least this version) requires irregular chunks.

Dask also does multi-stage rechunking, which brought significant efficiency gains (dask/dask#417). I considered copying Dask's rechunk planning algorithm here, but it involves a lot of complex logic so I decided to try replacing it with a simple heuristic instead. See algorithm.py for details.
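For reference, a sketch of how the new multi-stage planner is invoked (this mirrors the example script in the next section; each returned stage is a (read_chunks, int_chunks, write_chunks) triple):

from rechunker import algorithm

shape = (350640, 721, 1440)
source_chunks = (31, 721, 1440)
target_chunks = (350640, 10, 10)

# min_mem bounds the smallest intermediate chunk (in bytes); larger values
# force the planner to insert more stages.
stages = algorithm.multistage_rechunking_plan(
    shape, source_chunks, target_chunks,
    itemsize=4, min_mem=int(10e6), max_mem=int(500e6),
)
for read_chunks, int_chunks, write_chunks in stages:
    print(read_chunks, "->", int_chunks, "->", write_chunks)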

Example

My specific motivation for this PR is experimenting with rechunking on Pangeo's ERA5 single-level dataset, which contains a number of 1.5 TB variables (at least, once decoded into float32). Rechunking these arrays with shape=(350640, 721, 1440) from "whole image" chunks (31, 721, 1440) to "whole time-series" chunks (350640, 10, 10) with Rechunker's current algorithm produces a very large number of small chunks. It works, but seems much slower than it should be.

I wrote a little Python script to compare Rechunker's current method (min_mem=0), with my proposed multi-stage method (min_mem=10MB) and Dask's rechunking method:

from rechunker import algorithm
import dask.array.rechunk  # ensures the module is registered in sys.modules (used below)
import numpy as np
import sys

def evaluate_stage_v2(shape, read_chunks, int_chunks, write_chunks):
  tasks = algorithm.calculate_single_stage_io_ops(shape, read_chunks, write_chunks)
  read_tasks = tasks if write_chunks != read_chunks else 0
  write_tasks = tasks if read_chunks != int_chunks else 0
  return read_tasks, write_tasks

def evaluate_plan(stages, shape, itemsize):
  total_reads = 0
  total_writes = 0
  for i, stage in enumerate(stages):
    read_chunks, int_chunks, write_chunks = stage
    read_tasks, write_tasks = evaluate_stage_v2(
        shape, read_chunks, int_chunks, write_chunks,
    )
    total_reads += read_tasks
    total_writes += write_tasks
  return total_reads, total_writes

def print_summary(stages, shape, itemsize):
  for i, stage in enumerate(stages):
    print(f"stage={i}: " + " -> ".join(map(str, stage)))
    read_chunks, int_chunks, write_chunks = stage
    read_tasks, write_tasks = evaluate_stage_v2(
        shape, read_chunks, int_chunks, write_chunks,
    )
    print(f"  Tasks: {read_tasks} reads, {write_tasks} writes")
    print(f"  Split chunks: {itemsize*np.prod(int_chunks)/1e6 :1.3f} MB")

  total_reads, total_writes = evaluate_plan(stages, shape, itemsize)
  print("Overall:")
  print(f'  Reads count: {total_reads:1.3e}')
  print(f'  Write count: {total_writes:1.3e}')


# dask.array.rechunk resolves to the rechunk function, so look up the module
# (imported above) via sys.modules to reach its planning helpers.
rechunk_module = sys.modules['dask.array.rechunk']

def dask_plan(shape, source_chunks, target_chunks, threshold=None):
  source_expanded = rechunk_module.normalize_chunks(source_chunks, shape)
  target_expanded = rechunk_module.normalize_chunks(target_chunks, shape)
  # Note: itemsize seems to be ignored by default
  stages = rechunk_module.plan_rechunk(
      source_expanded, target_expanded, threshold=threshold, itemsize=4,
  )
  write_chunks = [tuple(s[0] for s in stage) for stage in stages]
  read_chunks = [source_chunks] + write_chunks[:-1]
  int_chunks = [algorithm._calculate_shared_chunks(r, w) 
                for r, w in zip(write_chunks, read_chunks)]
  return list(zip(read_chunks, int_chunks, write_chunks))

def rechunker_plan(shape, source_chunks, target_chunks, **kwargs):
  stages = algorithm.multistage_rechunking_plan(
      shape, source_chunks, target_chunks, **kwargs
  )
  return (
      [(source_chunks, source_chunks, stages[0][0])]
      + list(stages)
      + [(stages[-1][-1], target_chunks, target_chunks)]
  )


itemsize = 4
shape = (350640, 721, 1440)
source_chunks = (31, 721, 1440)
target_chunks = (350640, 10, 10)

print(f'Total size: {itemsize*np.prod(shape)/1e12:.3} TB')
print(f'Source chunk count: {np.prod(shape)/np.prod(source_chunks):1.3e}')
print(f'Target chunk count: {np.prod(shape)/np.prod(target_chunks):1.3e}')

print()
print("Rechunker plan (min_mem=0, max_mem=500 MB):")
plan = rechunker_plan(
    shape, source_chunks, target_chunks, itemsize=4, min_mem=0, max_mem=int(500e6)
)
print_summary(plan, shape, itemsize=4)

print()
print("Rechunker plan (min_mem=10 MB, max_mem=500 MB):")
plan = rechunker_plan(
    shape, source_chunks, target_chunks, itemsize=4, min_mem=int(10e6), max_mem=int(500e6)
)
print_summary(plan, shape, itemsize=4)

print()
print("Dask plan (default):")
plan = dask_plan(shape, source_chunks, target_chunks)
print_summary(plan, shape, itemsize=4)
Output:

Total size: 1.46 TB
Source chunk count: 1.131e+04
Target chunk count: 1.038e+04

Rechunker plan (min_mem=0, max_mem=500 MB):
stage=0: (31, 721, 1440) -> (31, 721, 1440) -> (93, 721, 1440)
  Tasks: 11311 reads, 0 writes
  Split chunks: 128.742 MB
stage=1: (93, 721, 1440) -> (93, 10, 30) -> (350640, 10, 30)
  Tasks: 13213584 reads, 13213584 writes
  Split chunks: 0.112 MB
stage=2: (350640, 10, 30) -> (350640, 10, 10) -> (350640, 10, 10)
  Tasks: 10512 reads, 10512 writes
  Split chunks: 140.256 MB
Overall:
  Reads count: 1.324e+07
  Write count: 1.322e+07

Rechunker plan (min_mem=10 MB, max_mem=500 MB):
stage=0: (31, 721, 1440) -> (31, 721, 1440) -> (93, 721, 1440)
  Tasks: 11311 reads, 0 writes
  Split chunks: 128.742 MB
stage=1: (93, 721, 1440) -> (93, 173, 396) -> (1447, 173, 396)
  Tasks: 80220 reads, 80220 writes
  Split chunks: 25.485 MB
stage=2: (1447, 173, 396) -> (1447, 41, 109) -> (22528, 41, 109)
  Tasks: 96492 reads, 96492 writes
  Split chunks: 25.867 MB
stage=3: (22528, 41, 109) -> (22528, 10, 30) -> (350640, 10, 30)
  Tasks: 86864 reads, 86864 writes
  Split chunks: 27.034 MB
stage=4: (350640, 10, 30) -> (350640, 10, 10) -> (350640, 10, 10)
  Tasks: 10512 reads, 10512 writes
  Split chunks: 140.256 MB
Overall:
  Reads count: 2.854e+05
  Write count: 2.741e+05

Dask plan (default):
stage=0: (31, 721, 1440) -> (31, 721, 1440) -> (32, 721, 1440)
  Tasks: 21915 reads, 0 writes
  Split chunks: 128.742 MB
stage=1: (32, 721, 1440) -> (32, 160, 480) -> (302, 160, 480)
  Tasks: 180705 reads, 180705 writes
  Split chunks: 9.830 MB
stage=2: (302, 160, 480) -> (302, 80, 150) -> (2922, 80, 150)
  Tasks: 153720 reads, 153720 writes
  Split chunks: 14.496 MB
stage=3: (2922, 80, 150) -> (2922, 20, 60) -> (29220, 20, 60)
  Tasks: 128760 reads, 128760 writes
  Split chunks: 14.026 MB
stage=4: (29220, 20, 60) -> (29220, 10, 10) -> (350640, 10, 10)
  Tasks: 126144 reads, 126144 writes
  Split chunks: 11.688 MB
Overall:
  Reads count: 6.112e+05
  Write count: 5.893e+05

Comparing my new multi-stage algorithm (min_mem=10 MB) to Rechunker's existing algorithm (min_mem=0), the multi-stage pipeline does two extra dataset copies but reduces the number of IO operations by ~50x.

Comparing my new algorithm to Dask's algorithm, the plans actually look remarkably similar. My estimates suggest that my algorithm should involve about half the number of IO operations, but Dask's plan uses slightly "nicer" chunk sizes. I have no idea which is better in practice, and note that I'm using Dask's algorithm without adjusting any of its control knobs.

I have not yet benchmarked any of these algorithms on real rechunking tasks. (Update: see the Xarray-Beam benchmarking results below, where multi-stage rechunking significantly improves real-world performance.)

@codecov (bot) commented on May 23, 2021

Codecov Report

Merging #89 (847766f) into master (9d12932) will increase coverage by 0.32%.
The diff coverage is 97.01%.

@@            Coverage Diff             @@
##           master      #89      +/-   ##
==========================================
+ Coverage   96.02%   96.35%   +0.32%     
==========================================
  Files          11       11              
  Lines         503      548      +45     
  Branches      112      105       -7     
==========================================
+ Hits          483      528      +45     
  Misses         13       13              
  Partials        7        7              
Impacted Files Coverage Δ
rechunker/api.py 97.12% <75.00%> (-0.92%) ⬇️
rechunker/algorithm.py 91.22% <100.00%> (+5.51%) ⬆️
rechunker/compat.py 100.00% <100.00%> (+22.22%) ⬆️


Comment on lines 171 to 174
if max_mem < min_mem:  # basic sanity test
    raise ValueError(
        f"max_mem ({max_mem}) cannot be smaller than min_mem ({min_mem})"
    )
@shoyer (Collaborator, Author) commented on May 23, 2021

We might consider making this more strict, e.g., requiring max_mem < 2 * min_mem. If min_mem is only slightly smaller than max_mem, the algorithm will exit before achieving the min_mem objective due to increasing IO op count. For example, I cannot exceed min_mem=350MB if max_mem=500MB in my ERA5 example.

@rabernat (Member) replied:

I think you mean requiring min_mem < 2 * max_mem. I would support adding that constraint.

@shoyer shoyer requested a review from rabernat May 23, 2021 05:17
@shoyer (Collaborator, Author) commented on May 23, 2021

Thinking about this a little bit more, this may be a harder problem for the rechunker data model than I thought. The problem is that my algorithm does not guarantee that intermediate chunks can be written to Zarr.

For example, consider stage 1 in my proposed rechunking plan above: (93, 721, 1440) -> (93, 173, 396) -> (1447, 173, 396). 1447 is not a multiple of 93 -- they have a greatest common divisor of 1 -- so even though the typical chunk along the first axis in the intermediates would be 93, in the worst case we'd have to save a single element along that axis.
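A quick illustration of that worst case (my own sketch, not Rechunker code):

import math

# Stage 1 writes chunks of 1447 along the first axis while reading chunks of
# 93. Because 1447 is not a multiple of 93, the pieces shared by a read chunk
# and a write chunk vary in size; gcd(93, 1447) == 1 means the smallest such
# piece can be a single element along that axis.
print(math.gcd(93, 1447))  # 1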

In Xarray-Beam and Dask, this is handled by producing temporary irregular chunks from the "split" step. This is fine in principle, but won't work for the current version of Rechunker, which stores intermediates in Zarr.

I see a few possible ways to resolve this:

  1. Extend Zarr to support irregular chunks
  2. Switch Rechunker to store irregular chunks in raw Zarr stores (or a filesystem) rather than as Zarr arrays
  3. Figure out another multi-stage algorithm that deals with these "regular multiple" constraints (this seems hard).

@rabernat (Member) commented:

Thanks a lot for this Stephan! It will take me a couple of days to digest this.

copybara-service bot pushed a commit to google/xarray-beam that referenced this pull request on May 24, 2021
This aligns the notion of rechunk "stage" more closely with a stage in multi-stage
rechunking: pangeo-data/rechunker#89
PiperOrigin-RevId: 375538957

copybara-service bot pushed a commit to google/xarray-beam that referenced this pull request on May 24, 2021
This aligns the notion of rechunk "stage" more closely with a stage in multi-stage
rechunking: pangeo-data/rechunker#89
PiperOrigin-RevId: 375553383
@shoyer (Collaborator, Author) commented on May 25, 2021

I tested this in Xarray-Beam today, rechunking a single 1.3 TB float32 array from (31, 721, 1440) -> (350640, 10, 10) with max_mem=256e6. This is large enough that Rechunker doesn't do any additional consolidation of reads/writes:

  • My single stage pipeline (the default from rechunker) ran in 154 minutes on a few hundred workers, with a minimum intermediate chunk size of 12.4 KB.
  • My two stage pipeline ran in 42 minutes on the same number of workers, with a minimum intermediate chunk size of 1.25 MB. The overall cost of the pipeline (including IO costs) was 68% lower.
  • The three stage pipeline (minimum chunk size of 5.88 MB) was slightly worse than the two stage pipeline (57% cost savings).

Edit: you can find the example code for this over in
google/xarray-beam#14

@shoyer (Collaborator, Author) commented on Nov 19, 2022

I am happy to rebase, but would anyone be up for reviewing this PR?

The alternative would be to fork this code into Xarray-Beam.

@shoyer (Collaborator, Author) commented on Nov 21, 2022

@rabernat any thoughts here?

@rabernat (Member) commented:

Hi Stephan! Sorry for letting this lie dormant for so long. Will definitely take a look this week. Thanks for your patience.

@shoyer (Collaborator, Author) commented on Nov 21, 2022

> Hi Stephan! Sorry for letting this lie dormant for so long. Will definitely take a look this week. Thanks for your patience.

Thanks!

I just merged in "master" so this should be ready for review.

@rabernat (Member) left a review:

Thanks so much for this PR Stephan! It looks incredibly clever, useful, and well tested.

I spent about an hour reviewing this PR. Maybe because it's late in the day, I was not able to fully understand everything that is going on at the algorithm level. In particular, I am not quite comprehending what gives rise to the issue of irregular intermediate chunks, which I believe is linked to the least-common-multiple calculation. My understanding is that the lcm calculation is only needed in calculate_single_stage_io_ops in order to figure out the number of i/o ops required in each stage. (The reason we didn't need this logic before was that we simply didn't care how many i/o ops were used to achieve the rechunking.) Source / target chunk pairs with small least common multiples can be copied with fewer i/o ops. Is that a correct interpretation?

So my main suggestion at this stage would be to inject a few more choice comments into the code, which will help other developers recreate the mental model you're using to reason about this stuff. This would make the code more maintainable by others.

Other than that, it all looks great, and all tests pass. 🎉

Going forward, we should explore the irregular chunks approach. You should check out https://zarr.dev/zeps/draft/ZEP0003.html, which proposes supporting this at the Zarr level.

# Add a small floating-point epsilon so we don't inadvertently
# round-down even chunk-sizes.
chunks = tuple(
    floor(rc ** (1 - power) * wc**power + epsilon)
@rabernat (Member) commented:

Would love a comment explaining the math. That would help make the code more maintainable by others.

@shoyer (Collaborator, Author) replied:

Actually, np.geomspace implements exactly what I was trying to do here. So I'll just use that instead!
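As an aside, a standalone sketch (my illustration, not the PR's code) showing that geometric spacing of the first-axis chunk size reproduces the write chunks in the min_mem=10 MB plan above:

import numpy as np

# Interpolate geometrically between the consolidated read chunk (93) and the
# final write chunk (350640) along the first axis, with three growth steps.
sizes = np.geomspace(93, 350640, num=4)
print([int(s) for s in sizes])  # [93, 1447, 22528, 350640]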



def _count_num_splits(source_chunk: int, target_chunk: int, size: int) -> int:
    multiple = lcm(source_chunk, target_chunk)
@rabernat (Member) commented:

Can you help explain why the least-common multiple calculation is needed here? We managed to avoid these types of constraints (and the edge cases that come with them) in the original algorithm, and I'm not quite grokking why it's needed here. Clearly this is related to the irregular chunk problem.

@shoyer (Collaborator, Author) replied:

This is just about counting how many operations would be required for a rechunk operation. I wrote a docstring with an example to help clarify.
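A rough standalone illustration of the role the lcm plays in that count (my own sketch, not the PR's exact implementation):

from math import lcm

# Along one axis, the pattern of source and target chunk boundaries repeats
# every lcm(source_chunk, target_chunk) elements, and each period is cut into
# lcm // source_chunk + lcm // target_chunk - 1 pieces by the combined
# boundaries. Counting pieces this way gives the per-axis io-op count.
source_chunk, target_chunk, size = 4, 6, 24
period = lcm(source_chunk, target_chunk)                                 # 12
pieces_per_period = period // source_chunk + period // target_chunk - 1  # 4
print((size // period) * pieces_per_period)                              # 8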


Comment on lines +288 to +299
raise AssertionError(
    "Failed to find a feasible multi-staging rechunking scheme satisfying "
    f"min_mem ({min_mem}) and max_mem ({max_mem}) constraints. "
    "Please file a bug report on GitHub: "
    "https://github.com/pangeo-data/rechunker/issues\n\n"
    "Include the following debugging info:\n"
    f"shape={shape}, source_chunks={source_chunks}, "
    f"target_chunks={target_chunks}, itemsize={itemsize}, "
    f"min_mem={min_mem}, max_mem={max_mem}, "
    f"consolidate_reads={consolidate_reads}, "
    f"consolidate_writes={consolidate_writes}"
)
@rabernat (Member) commented:

Great error message!

new_chunks : tuple
    The new chunks, size guaranteed to be <= max_mem
"""
(stage,) = multistage_rechunking_plan(
@rabernat (Member) commented:

Really appreciate how this was implemented with clean backwards compatibility.

((100,), (43,), (51,), 4),
((100,), (43,), (40,), 5),
((100,), (43,), (10,), 12),
((100,), (43,), (1,), 100),
@rabernat (Member) commented:

I have the impression that there is a lot of intuition and understanding encoded into the above test cases. It would be nice to get more of that into comments to help other devs understand the logic.

@shoyer (Collaborator, Author) replied:

added some brief comments :)
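My reading of those cases (an assumption on my part) is that each tuple is (shape, source_chunks, target_chunks, expected_io_ops) for calculate_single_stage_io_ops; overlaying both sets of chunk boundaries on a length-100 axis reproduces the expected counts:

def io_ops_1d(size: int, source: int, target: int) -> int:
    # Pieces formed by the union of source and target chunk boundaries.
    boundaries = set(range(0, size, source)) | set(range(0, size, target)) | {size}
    return len(boundaries) - 1

assert io_ops_1d(100, 43, 51) == 4
assert io_ops_1d(100, 43, 40) == 5
assert io_ops_1d(100, 43, 10) == 12
assert io_ops_1d(100, 43, 1) == 100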

def calculate_single_stage_io_ops(
    shape: Sequence[int], in_chunks: Sequence[int], out_chunks: Sequence[int]
) -> int:
    """Estimate the number of irregular chunks required for rechunking."""
@rabernat (Member) commented:

I cannot reconcile the docstring with the function name. Where are "irregular chunks" being calculated?

@shoyer (Collaborator, Author) replied:

removed the note about irregular chunks -- that's really unrelated.

@shoyer (Collaborator, Author) commented on Nov 22, 2022

> In particular, I am not quite comprehending what gives rise to the issue of irregular intermediate chunks, which I believe is linked to the least-common-multiple calculation.

This is a good point! The code as I wrote it was rather misleading -- the LCM logic is indeed just about counting IO ops.

Irregular chunks only become necessary because of the very simple algorithm (geometric spacing) I use for picking intermediate chunk sizes in multi-stage plans. With some cleverness, this could likely be avoided in many cases.

The current rechunker algorithm does actually support "irregular chunks" in the form of overlapping reads of source arrays. So it's certainly possible to use this limited form of irregular chunking to support arbitrary rechunks, and I suspect with effort we could make it work for multi-stage rechunks, too. The most obvious way to do this would be to simply pick some intermediate chunk size, and then run the existing rechunker algorithm recursively twice, for three total temporary Zarr arrays (or 7 or 15). This would be quite a bit less flexible than what I implemented here, though.

@rabernat (Member) commented:

Fantastic. These comments are a big help. Thanks for this contribution.

@rabernat rabernat merged commit 46db807 into pangeo-data:master Nov 22, 2022
copybara-service bot pushed a commit to google/xarray-beam that referenced this pull request on Mar 21, 2023
As described in pangeo-data/rechunker#89,
this can yield significant performance benefits for rechunking large
arrays.

PiperOrigin-RevId: 375836954

copybara-service bot pushed a commit to google/xarray-beam that referenced this pull request on Mar 21, 2023
As described in pangeo-data/rechunker#89,
this can yield significant performance benefits for rechunking large
arrays.

PiperOrigin-RevId: 518325665