Dataset Versioning #248

Closed · wants to merge 9 commits

11 changes: 9 additions & 2 deletions daskms/dataset.py
@@ -8,6 +8,7 @@
from collections import OrderedDict

import dask
from daskms.constants import DASKMS_METADATA
import dask.array as da
from dask.highlevelgraph import HighLevelGraph
import numpy as np
@@ -103,7 +104,7 @@ def _convert_to_variable(k, v):


if xr is not None:
from xarray import Dataset, Variable
from xarray import Dataset as BaseDataset, Variable
else:
# This class duplicates xarray's Frozen class in
# https://github.com/pydata/xarray/blob/master/xarray/core/utils.py
@@ -235,7 +236,7 @@ def __dask_postpersist__(self):
args = (fn, args, self.data.name, self.dims, self.attrs)
return (self.finalize_persist, args)

class Dataset:
class BaseDataset:
"""
Replicates a minimal subset of `xarray Dataset
<http://xarray.pydata.org/en/stable/generated/xarray.Dataset.html#xarray.Dataset>`_'s
@@ -531,3 +532,9 @@ def __dask_postpersist__(self):
for k, v in self._data_vars.items()
]
return self.finalize_persist, (data_info, self._coords, self._attrs)


class Dataset(BaseDataset):
def __init__(self, data_vars, coords=None, attrs=None):
attrs = {DASKMS_METADATA: {}, **(attrs or {})}
super().__init__(data_vars, coords=coords, attrs=attrs)
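For illustration, a minimal sketch of what the new `Dataset` wrapper guarantees. It assumes data variables are passed as `{name: (dims, array)}` pairs, as the existing `BaseDataset` interface accepts; the variable names, shapes and the `"test.ms"` string below are hypothetical:

```python
import dask.array as da

from daskms.constants import DASKMS_METADATA
from daskms.dataset import Dataset

# Hypothetical data variable: a chunked (row, chan, corr) array
data = da.zeros((10, 16, 4), chunks=(5, 16, 4))
data_vars = {"DATA": (("row", "chan", "corr"), data)}

# With no attrs supplied, an empty metadata dict is injected
ds = Dataset(data_vars)
assert ds.attrs[DASKMS_METADATA] == {}

# Caller-supplied metadata under DASKMS_METADATA takes precedence,
# because it is unpacked after the injected default
ds = Dataset(data_vars, attrs={DASKMS_METADATA: {"provenance": ["test.ms"]}})
assert ds.attrs[DASKMS_METADATA] == {"provenance": ["test.ms"]}
```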
1 change: 1 addition & 0 deletions daskms/table_proxy.py
@@ -29,6 +29,7 @@
("nrows", READLOCK),
("colnames", READLOCK),
("getcoldesc", READLOCK),
("getdesc", READLOCK),
("getdminfo", READLOCK),
("iswritable", READLOCK),
# Modification
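For context, a minimal sketch of how the newly registered method would presumably be called. The constructor arguments and the future-returning `.result()` idiom are assumptions for illustration, mirroring the neighbouring read-locked methods rather than anything shown in this diff:

```python
import pyrap.tables as pt

from daskms.table_proxy import TableProxy

# Hypothetical proxy around an existing Measurement Set
table_proxy = TableProxy(pt.table, "test.ms", ack=False, readonly=True)

# Assumed idiom: methods registered with READLOCK become proxied calls
# that acquire a read lock and return a future
desc = table_proxy.getdesc().result()      # full table descriptor
dminfo = table_proxy.getdminfo().result()  # data manager info, already proxied
```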
21 changes: 21 additions & 0 deletions daskms/tests/test_metadata.py
@@ -0,0 +1,21 @@
import dask
import pyrap.tables as pt
import pytest

from daskms.constants import DASKMS_METADATA
from daskms import xds_from_storage_ms, xds_to_table


@pytest.mark.xfail
def test_provenance(ms, tmp_path_factory):
datasets = xds_from_storage_ms(ms)

for ds in datasets:
assert ds.attrs[DASKMS_METADATA]["provenance"] == [ms]

data_dir = tmp_path_factory.mktemp("provenance")
store = str(data_dir / "blah.ms")
dask.compute(xds_to_table(datasets, store))

with pt.table(str(store), ack=False) as T:
assert T.getkeywords()[DASKMS_METADATA] == {"provenance": [ms, store]}
11 changes: 9 additions & 2 deletions daskms/tests/test_ms_read_and_update.py
@@ -146,12 +146,14 @@ def test_ms_update(ms, group_cols, index_cols, select_cols):
for k, _ in nds.attrs[DASKMS_METADATA][DASKMS_PARTITION_KEY]:
assert getattr(write, k) == getattr(nds, k)

# assert ds.attrs[DASKMS_METADATA]["provenance"] == [ms]

writes.append(write)

# Do all writes in parallel
dask.compute(writes)

xds = xds_from_ms(
rxds = xds_from_ms(
ms,
columns=select_cols,
group_cols=group_cols,
@@ -160,11 +162,16 @@
)

# Check that state and data have been correctly written
it = enumerate(zip(xds, written_states, written_data))
it = enumerate(zip(rxds, written_states, written_data))
for i, (ds, state, data) in it:
assert_array_equal(ds.STATE_ID.data, state)
assert_array_equal(ds.DATA.data, data)

orig_part_key = xds[i].attrs[DASKMS_METADATA][DASKMS_PARTITION_KEY]
assert ds.attrs[DASKMS_METADATA][DASKMS_PARTITION_KEY] == orig_part_key
# assert ds.attrs[DASKMS_METADATA]["provenance"] == [ms]
assert len(ds.attrs[DASKMS_METADATA][DASKMS_PARTITION_KEY]) == len(group_cols)


@pytest.mark.parametrize(
"index_cols",
32 changes: 28 additions & 4 deletions daskms/writes.py
@@ -10,7 +10,7 @@
import pyrap.tables as pt

from daskms.columns import dim_extents_array
from daskms.constants import DASKMS_PARTITION_KEY, DASKMS_METADATA
from daskms.constants import DASKMS_METADATA, DASKMS_PARTITION_KEY
from daskms.dataset import Dataset
from daskms.dataset_schema import DatasetSchema
from daskms.descriptors.builder import AbstractDescriptorBuilder
@@ -20,7 +20,7 @@
from daskms.table import table_exists
from daskms.table_executor import executor_key
from daskms.table_proxy import TableProxy, WRITELOCK
from daskms.utils import table_path_split
from daskms.utils import table_path_split, freeze


log = logging.getLogger(__name__)
@@ -566,6 +566,30 @@ def _write_datasets(
table_name = "::".join((table_name, subtable)) if subtable else table_name
row_orders = []

# frozen_meta = set(freeze(ds.attrs.get(DASKMS_METADATA, {})) for ds in datasets)

# if len(frozen_meta) == 0:
# metadata = {}
# elif len(frozen_meta) == 1:
# metadata = datasets[0].attrs.get(DASKMS_METADATA, {})
# else:
# raise ValueError(f"{DASKMS_METADATA} is not consistent across datasets")

# import json

# table_keywords = table_keywords or {}
# table_metadata = table_keywords.get(DASKMS_METADATA, {})
# table_keywords[DASKMS_METADATA] = {**metadata, **table_metadata}
# provenance = table_keywords[DASKMS_METADATA].setdefault("provenance", [])
# table_keywords[DASKMS_METADATA] = json.dumps(table_keywords[DASKMS_METADATA])

# try:
# provenance.remove(table)
# except ValueError:
# pass

# provenance.append(table)

# Put table and column keywords
table_proxy.submit(
_put_keywords, WRITELOCK, table_keywords, column_keywords
@@ -714,14 +738,14 @@


def _put_keywords(table, table_keywords, column_keywords):
if table_keywords is not None:
if table_keywords:
for k, v in table_keywords.items():
if v == DELKW:
table.removekeyword(k)
else:
table.putkeyword(k, v)

if column_keywords is not None:
if column_keywords:
for column, keywords in column_keywords.items():
for k, v in keywords.items():
if v == DELKW:
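The disabled block in `_write_datasets` above sketches how per-dataset metadata would be folded into the table keywords before `_put_keywords` runs. Below is a standalone illustration of that intent against plain dicts; the helper name and the exact ordering of operations are hypothetical and not part of this diff:

```python
import json

from daskms.constants import DASKMS_METADATA


def merge_provenance_keywords(dataset_metadata, table_keywords, table):
    """Hypothetical helper illustrating the disabled provenance logic.

    ``dataset_metadata`` is the DASKMS_METADATA dict shared by all datasets
    (assumed already checked for consistency across datasets),
    ``table_keywords`` the keywords about to be written and ``table``
    the output table path.
    """
    table_keywords = dict(table_keywords or {})
    table_metadata = table_keywords.get(DASKMS_METADATA, {})

    # Dataset metadata provides defaults; existing table metadata wins
    merged = {**dataset_metadata, **table_metadata}

    # Move this table to the end of the provenance chain
    provenance = list(merged.get("provenance", []))
    if table in provenance:
        provenance.remove(table)
    provenance.append(table)
    merged["provenance"] = provenance

    # The disabled code serialises the metadata to JSON before writing
    table_keywords[DASKMS_METADATA] = json.dumps(merged)
    return table_keywords


# Example: metadata read from "test.ms", now being written to "blah.ms"
keywords = merge_provenance_keywords({"provenance": ["test.ms"]}, {}, "blah.ms")
assert json.loads(keywords[DASKMS_METADATA]) == {"provenance": ["test.ms", "blah.ms"]}
```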