Skip to content

Commit

Permalink
use factory for anemoi-dataset copy
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Mar 4, 2024
1 parent 3e4a00e commit aa767b4
Showing 1 changed file with 187 additions and 184 deletions.
371 changes: 187 additions & 184 deletions ecml_tools/commands/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,198 +39,201 @@
"""


class Create(Command):
internal = True
timestamp = True

def add_arguments(self, command_parser):
command_parser.add_argument("--source", required=True)
command_parser.add_argument("--target", required=True)
command_parser.add_argument("--transfers", type=int, default=8)
command_parser.add_argument("--block-size", type=int, default=100)
command_parser.add_argument("--overwrite", action="store_true")
command_parser.add_argument("--progress", action="store_true")
command_parser.add_argument(
"--rechunk",
nargs="+",
help="Rechunk given array.",
metavar="array=i,j,k,l",
)

def _store(self, path):
return path

def copy_chunk(self, n, m, source, target, block_size, _copy, progress):
if _copy[n:m].all():
LOG.info(f"Skipping {n} to {m}")
return None

for i in tqdm.tqdm(
range(n, m),
desc=f"Copying {n} to {m}",
leave=False,
disable=not isatty and not progress,
):
target[i] = source[i]
return slice(n, m)

def copy_data(self, source, target, transfers, block_size, _copy, progress, rechunking):
LOG.info("Copying data")
source_data = source["data"]

chunks = list(source_data.chunks)
if "data" in rechunking:
assert len(chunks) == len(rechunking["data"]), (chunks, rechunking["data"])
for i, c in enumerate(rechunking["data"]):
if c != -1:
chunks[i] = c

target_data = (
target["data"]
if "data" in target
else target.create_dataset(
"data",
shape=source_data.shape,
chunks=chunks,
dtype=source_data.dtype,
def class_builder(_cls):
class Create(_cls):
internal = True
timestamp = True

def add_arguments(self, command_parser):
command_parser.add_argument("--source", required=True)
command_parser.add_argument("--target", required=True)
command_parser.add_argument("--transfers", type=int, default=8)
command_parser.add_argument("--block-size", type=int, default=100)
command_parser.add_argument("--overwrite", action="store_true")
command_parser.add_argument("--progress", action="store_true")
command_parser.add_argument(
"--rechunk",
nargs="+",
help="Rechunk given array.",
metavar="array=i,j,k,l",
)
)

executor = ThreadPoolExecutor(max_workers=transfers)
tasks = []
n = 0
while n < target_data.shape[0]:
tasks.append(
executor.submit(
self.copy_chunk,
n,
min(n + block_size, target_data.shape[0]),
source_data,
target_data,
block_size,
_copy,
progress,

def _store(self, path):
return path

def copy_chunk(self, n, m, source, target, block_size, _copy, progress):
if _copy[n:m].all():
LOG.info(f"Skipping {n} to {m}")
return None

for i in tqdm.tqdm(
range(n, m),
desc=f"Copying {n} to {m}",
leave=False,
disable=not isatty and not progress,
):
target[i] = source[i]
return slice(n, m)

def copy_data(self, source, target, transfers, block_size, _copy, progress, rechunking):
LOG.info("Copying data")
source_data = source["data"]

chunks = list(source_data.chunks)
if "data" in rechunking:
assert len(chunks) == len(rechunking["data"]), (chunks, rechunking["data"])
for i, c in enumerate(rechunking["data"]):
if c != -1:
chunks[i] = c

target_data = (
target["data"]
if "data" in target
else target.create_dataset(
"data",
shape=source_data.shape,
chunks=chunks,
dtype=source_data.dtype,
)
)
n += block_size

for future in tqdm.tqdm(as_completed(tasks), total=len(tasks), smoothing=0):
copied = future.result()
if copied is not None:
_copy[copied] = True
target["_copy"] = _copy

target["_copy"] = _copy

LOG.info("Copied data")

def copy_array(self, name, source, target, transfers, block_size, _copy, progress, rechunking):
for k, v in source.attrs.items():
target.attrs[k] = v

if name == "_copy":
return

if name == "data":
self.copy_data(source, target, transfers, block_size, _copy, progress, rechunking)
return

LOG.info(f"Copying {name}")
target[name] = source[name]
LOG.info(f"Copied {name}")

def copy_group(self, source, target, transfers, block_size, _copy, progress, rechunking):
import zarr

for k, v in source.attrs.items():
target.attrs[k] = v

for name in sorted(source.keys()):
if isinstance(source[name], zarr.hierarchy.Group):
group = target[name] if name in target else target.create_group(name)
self.copy_group(
source[name],
group,
transfers,
block_size,
_copy,
progress,
rechunking,
)
else:
self.copy_array(
name,
source,
target,
transfers,
block_size,
_copy,
progress,
rechunking,

executor = ThreadPoolExecutor(max_workers=transfers)
tasks = []
n = 0
while n < target_data.shape[0]:
tasks.append(
executor.submit(
self.copy_chunk,
n,
min(n + block_size, target_data.shape[0]),
source_data,
target_data,
block_size,
_copy,
progress,
)
)
n += block_size

def copy(self, source, target, transfers, block_size, progress, rechunking):
import zarr
for future in tqdm.tqdm(as_completed(tasks), total=len(tasks), smoothing=0):
copied = future.result()
if copied is not None:
_copy[copied] = True
target["_copy"] = _copy

if "_copy" not in target:
target["_copy"] = zarr.zeros(
source["data"].shape[0],
dtype=bool,
)
_copy = target["_copy"]
_copy_np = _copy[:]

self.copy_group(source, target, transfers, block_size, _copy_np, progress, rechunking)
del target["_copy"]

def run(self, args):
import zarr

# base, ext = os.path.splitext(os.path.basename(args.source))
# assert ext == ".zarr", ext
# assert "." not in base, base
LOG.info(f"Copying {args.source} to {args.target}")

rechunking = {}
if args.rechunk:
for r in args.rechunk:
k, v = r.split("=")
if k != "data":
raise ValueError(f"Only rechunking data is supported: {k}")
values = v.split(",")
values = [-1 if x == "" else x for x in values]
values = tuple(int(x) for x in values)
rechunking[k] = values
for k, v in rechunking.items():
LOG.info(f"Rechunking {k} to {v}")

try:
target = zarr.open(self._store(args.target), mode="r")
if "_copy" in target:
done = sum(1 if x else 0 for x in target["_copy"])
todo = len(target["_copy"])
LOG.info(
"Resuming copy, done %s out or %s, %s%%",
done,
todo,
int(done / todo * 100 + 0.5),
)
elif "sums" in target and "data" in target: # sums is copied last
LOG.error("Target already exists")
target["_copy"] = _copy

LOG.info("Copied data")

def copy_array(self, name, source, target, transfers, block_size, _copy, progress, rechunking):
for k, v in source.attrs.items():
target.attrs[k] = v

if name == "_copy":
return

if name == "data":
self.copy_data(source, target, transfers, block_size, _copy, progress, rechunking)
return
except ValueError as e:
LOG.info(f"Target does not exist: {e}")
pass

source = zarr.open(self._store(args.source), mode="r")
if args.overwrite:
target = zarr.open(self._store(args.target), mode="w")
else:

LOG.info(f"Copying {name}")
target[name] = source[name]
LOG.info(f"Copied {name}")

def copy_group(self, source, target, transfers, block_size, _copy, progress, rechunking):
import zarr

for k, v in source.attrs.items():
target.attrs[k] = v

for name in sorted(source.keys()):
if isinstance(source[name], zarr.hierarchy.Group):
group = target[name] if name in target else target.create_group(name)
self.copy_group(
source[name],
group,
transfers,
block_size,
_copy,
progress,
rechunking,
)
else:
self.copy_array(
name,
source,
target,
transfers,
block_size,
_copy,
progress,
rechunking,
)

def copy(self, source, target, transfers, block_size, progress, rechunking):
import zarr

if "_copy" not in target:
target["_copy"] = zarr.zeros(
source["data"].shape[0],
dtype=bool,
)
_copy = target["_copy"]
_copy_np = _copy[:]

self.copy_group(source, target, transfers, block_size, _copy_np, progress, rechunking)
del target["_copy"]

def run(self, args):
import zarr

# base, ext = os.path.splitext(os.path.basename(args.source))
# assert ext == ".zarr", ext
# assert "." not in base, base
LOG.info(f"Copying {args.source} to {args.target}")

rechunking = {}
if args.rechunk:
for r in args.rechunk:
k, v = r.split("=")
if k != "data":
raise ValueError(f"Only rechunking data is supported: {k}")
values = v.split(",")
values = [-1 if x == "" else x for x in values]
values = tuple(int(x) for x in values)
rechunking[k] = values
for k, v in rechunking.items():
LOG.info(f"Rechunking {k} to {v}")

try:
target = zarr.open(self._store(args.target), mode="w+")
except ValueError:
target = zarr.open(self._store(args.target), mode="r")
if "_copy" in target:
done = sum(1 if x else 0 for x in target["_copy"])
todo = len(target["_copy"])
LOG.info(
"Resuming copy, done %s out or %s, %s%%",
done,
todo,
int(done / todo * 100 + 0.5),
)
elif "sums" in target and "data" in target: # sums is copied last
LOG.error("Target already exists")
return
except ValueError as e:
LOG.info(f"Target does not exist: {e}")
pass

source = zarr.open(self._store(args.source), mode="r")
if args.overwrite:
target = zarr.open(self._store(args.target), mode="w")
self.copy(source, target, args.transfers, args.block_size, args.progress, rechunking)
else:
try:
target = zarr.open(self._store(args.target), mode="w+")
except ValueError:
target = zarr.open(self._store(args.target), mode="w")
self.copy(source, target, args.transfers, args.block_size, args.progress, rechunking)

return Create


command = Create
command = class_builder(Command)

0 comments on commit aa767b4

Please sign in to comment.