3430 support dataframes and streams in CSVDataset #3440

Merged · 18 commits · Dec 6, 2021
25 changes: 19 additions & 6 deletions monai/data/dataset.py
@@ -32,7 +32,7 @@

from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
from monai.utils import MAX_SEED, ensure_tuple, get_seed, look_up_option, min_version, optional_import
from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

if TYPE_CHECKING:
@@ -1222,8 +1222,9 @@ class CSVDataset(Dataset):
]

Args:
filename: the filename of expected CSV file to load. if providing a list
of filenames, it will load all the files and join tables.
src: if provided as the filename of a CSV file, it can be a str, URL, path object
    or file-like object to load. a pandas `DataFrame` can also be provided directly,
    in which case loading from a filename is skipped. if a list of filenames or
    `DataFrame` objects is provided, the tables will be joined.
row_indices: indices of the expected rows to load. it should be a list,
every item can be a int number or a range `[start, end)` for the indices.
for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
@@ -1249,20 +1250,32 @@ class CSVDataset(Dataset):
transform: transform to apply on the loaded items of a dictionary data.
kwargs: additional arguments for `pandas.merge()` API to join tables.

.. deprecated:: 0.8.0
``filename`` is deprecated, use ``src`` instead.

"""

@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
def __init__(
self,
filename: Union[str, Sequence[str]],
src: Optional[Union[str, Sequence[str]]] = None,  # also can be `DataFrame` or a sequence of `DataFrame`
row_indices: Optional[Sequence[Union[int, str]]] = None,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
**kwargs,
):
files = ensure_tuple(filename)
dfs = [pd.read_csv(f) for f in files]
srcs = (src,) if not isinstance(src, (tuple, list)) else src
dfs: List = []
for i in srcs:
    if isinstance(i, str):
        dfs.append(pd.read_csv(i))
    elif isinstance(i, pd.DataFrame):
        dfs.append(i)
    else:
        raise ValueError("`src` must be a file path or a pandas `DataFrame`.")

data = convert_tables_to_dicts(
dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs
)
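
As a usage sketch (not part of the diff; the file names and the `subject_id` column are hypothetical), the new `src` argument accepts pre-loaded DataFrames as well as file paths:

```python
import pandas as pd

from monai.data import CSVDataset

# hypothetical CSV files that share a "subject_id" column
df_a = pd.read_csv("test_data1.csv")
df_b = pd.read_csv("test_data2.csv")

# pass DataFrames directly; extra kwargs such as `on` are forwarded to pandas.merge()
dataset = CSVDataset(src=[df_a, df_b], on="subject_id")
print(len(dataset), dataset[0]["subject_id"])

# a single file path (or a list of paths) still works as before
dataset = CSVDataset(src="test_data1.csv")
```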
59 changes: 47 additions & 12 deletions monai/data/iterable_dataset.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

import numpy as np
from torch.utils.data import IterableDataset as _TorchIterableDataset
@@ -18,7 +18,7 @@
from monai.data.utils import convert_tables_to_dicts
from monai.transforms import apply_transform
from monai.transforms.transform import Randomizable
from monai.utils import ensure_tuple, optional_import
from monai.utils import deprecated_arg, optional_import

pd, _ = optional_import("pandas")

@@ -147,8 +147,9 @@ class CSVIterableDataset(IterableDataset):
]

Args:
filename: the filename of CSV file to load. it can be a str, URL, path object or file-like object.
if providing a list of filenames, it will load all the files and join tables.
src: if provided as the filename of a CSV file, it can be a str, URL, path object
    or file-like object to load. an iterator can also be provided directly for stream
    input, in which case loading from a filename is skipped. if a list of filenames
    or iterators is provided, the tables will be joined.
chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`.
@@ -177,11 +178,15 @@
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
kwargs: additional arguments for `pandas.merge()` API to join tables.

.. deprecated:: 0.8.0
``filename`` is deprecated, use ``src`` instead.

"""

@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
def __init__(
self,
filename: Union[str, Sequence[str]],
src: Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]],
chunksize: int = 1000,
buffer_size: Optional[int] = None,
col_names: Optional[Sequence[str]] = None,
@@ -192,7 +197,7 @@ def __init__(
seed: int = 0,
**kwargs,
):
self.files = ensure_tuple(filename)
self.src = src
self.chunksize = chunksize
self.buffer_size = 2 * chunksize if buffer_size is None else buffer_size
self.col_names = col_names
@@ -201,16 +206,46 @@ def __init__(
self.shuffle = shuffle
self.seed = seed
self.kwargs = kwargs
self.iters = self.reset()
self.iters: List[Iterable] = self.reset()
super().__init__(data=None, transform=transform) # type: ignore

def reset(self, filename: Optional[Union[str, Sequence[str]]] = None):
if filename is not None:
# update files if necessary
self.files = ensure_tuple(filename)
self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files]
@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]]] = None):
"""
Reset the pandas `TextFileReader` iterable object to read data. For more details, please check:
https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

Args:
src: if not None, it can be the filename of a CSV file to load: a str, URL, path object
    or file-like object. an iterator can also be provided directly for stream input,
    in which case loading from a filename is skipped. if a list of filenames or
    iterators is provided, the tables will be joined. defaults to `self.src`.

"""
src = self.src if src is None else src
srcs = (src,) if not isinstance(src, (tuple, list)) else src
self.iters = []
for i in srcs:
    if isinstance(i, str):
        self.iters.append(pd.read_csv(i, chunksize=self.chunksize))
    elif isinstance(i, Iterable):
        self.iters.append(i)
    else:
        raise ValueError("`src` must be a file path or an iterable object.")
return self.iters

def close(self):
"""
Close the pandas `TextFileReader` iterable objects.
If the input src is a file path, the `TextFileReader` was created internally and needs to be closed here.
If the input src is an iterable object, it is up to the user's requirements whether to close it in this function.
For more details, please check:
https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

"""
for i in self.iters:
i.close()

def _flattened(self):
for chunks in zip(*self.iters):
yield from convert_tables_to_dicts(
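A hedged usage sketch for the stream support (not part of the diff; the file names and the `subject_id` column are hypothetical): `src` can mix file paths and pre-opened iterators, such as the `TextFileReader` returned by `pandas.read_csv(..., chunksize=...)`:

```python
import pandas as pd

from monai.data import CSVIterableDataset

# a pre-opened chunked stream (pandas TextFileReader)
stream = pd.read_csv("test_data2.csv", chunksize=1000)

# mix a file path and a stream; `on` is forwarded to pandas.merge() to join the tables
dataset = CSVIterableDataset(src=["test_data1.csv", stream], chunksize=1000, on="subject_id")
for item in dataset:
    print(item["subject_id"])

# re-read from the beginning, supplying a fresh stream, then release the readers
dataset.reset(src=["test_data1.csv", pd.read_csv("test_data2.csv", chunksize=1000)])
dataset.close()
```

Note that a consumed user-supplied stream cannot be rewound, which is why `reset()` above is given a freshly opened reader rather than the exhausted one.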
2 changes: 0 additions & 2 deletions monai/networks/blocks/activation.py
@@ -19,7 +19,6 @@
def monai_mish(x, inplace: bool = False):
return torch.nn.functional.mish(x, inplace=inplace)


else:

def monai_mish(x, inplace: bool = False):
@@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False):
def monai_swish(x, inplace: bool = False):
return torch.nn.functional.silu(x, inplace=inplace)


else:

def monai_swish(x, inplace: bool = False):
37 changes: 30 additions & 7 deletions tests/test_csv_dataset.py
@@ -14,6 +14,7 @@
import unittest

import numpy as np
import pandas as pd

from monai.data import CSVDataset
from monai.transforms import ToNumpyd
@@ -57,6 +58,7 @@ def prepare_csv_file(data, filepath):
filepath1 = os.path.join(tempdir, "test_data1.csv")
filepath2 = os.path.join(tempdir, "test_data2.csv")
filepath3 = os.path.join(tempdir, "test_data3.csv")
filepaths = [filepath1, filepath2, filepath3]
prepare_csv_file(test_data1, filepath1)
prepare_csv_file(test_data2, filepath2)
prepare_csv_file(test_data3, filepath3)
@@ -76,7 +78,7 @@ def prepare_csv_file(data, filepath):
)

# test multiple CSV files, join tables with kwargs
dataset = CSVDataset([filepath1, filepath2, filepath3], on="subject_id")
dataset = CSVDataset(filepaths, on="subject_id")
self.assertDictEqual(
{k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()},
{
@@ -102,7 +104,7 @@

# test selected rows and columns
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
src=filepaths,
row_indices=[[0, 2], 3], # load row: 0, 1, 3
col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
)
@@ -120,7 +122,7 @@ def prepare_csv_file(data, filepath):

# test group columns
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
src=filepaths,
row_indices=[1, 3], # load row: 1, 3
col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]},
@@ -133,9 +135,7 @@ def prepare_csv_file(data, filepath):

# test transform
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
transform=ToNumpyd(keys="ehr"),
src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr")
)
self.assertEqual(len(dataset), 5)
expected = [
@@ -151,7 +151,7 @@ def prepare_csv_file(data, filepath):

# test default values and dtype
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
src=filepaths,
col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"],
col_types={"image": {"type": str, "default": "No image"}, "ehr_1": {"type": int, "default": 0}},
how="outer", # generate NaN values in this merge mode
@@ -161,6 +161,29 @@ def prepare_csv_file(data, filepath):
self.assertEqual(type(dataset[-1]["ehr_1"]), int)
np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2)

# test pre-loaded DataFrame
df = pd.read_csv(filepath1)
dataset = CSVDataset(src=df)
self.assertDictEqual(
{k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},
{
"subject_id": "s000002",
"label": 4,
"image": "./imgs/s000002.png",
"ehr_0": 3.7725,
"ehr_1": 4.2118,
"ehr_2": 4.6353,
},
)

# test pre-loaded multiple DataFrames, join tables with kwargs
dfs = [pd.read_csv(i) for i in filepaths]
dataset = CSVDataset(src=dfs, on="subject_id")
self.assertEqual(dataset[3]["subject_id"], "s000003")
self.assertEqual(dataset[3]["label"], 1)
self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333)
self.assertEqual(dataset[3]["meta_0"], False)


if __name__ == "__main__":
unittest.main()
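
As a final hedged sketch (assuming, as the decorator usage above suggests, that `deprecated_arg` remaps the old keyword to `src` and emits a `DeprecationWarning`), existing code that passes `filename` keeps working during the deprecation window:

```python
import warnings

from monai.data import CSVDataset

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # hypothetical file; the deprecated keyword is forwarded to `src`
    dataset = CSVDataset(filename="test_data1.csv")

print([str(w.message) for w in caught])  # expect a note that `filename` is deprecated
```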