Skip to content

Commit

Permalink
3710 Add support to set delimiter for CSV files and change default to…
Browse files Browse the repository at this point in the history
… comma (#3711)

* [DLMED] add support to config delimiter

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix typo

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma authored and wyli committed Feb 9, 2022
1 parent e964ba0 commit abe24cf
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 14 deletions.
6 changes: 5 additions & 1 deletion monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
filename: str = "predictions.csv",
overwrite: bool = True,
flush: bool = False,
delimiter: str = ",",
) -> None:
"""
Args:
Expand All @@ -48,6 +49,8 @@ def __init__(
otherwise, will append new content to the CSV file.
flush: whether to write the cache data to CSV file immediately when `save_batch` and clear the cache.
default to False.
delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
"""
self.output_dir = Path(output_dir)
Expand All @@ -59,6 +62,7 @@ def __init__(
os.remove(self._filepath)

self.flush = flush
self.delimiter = delimiter
self._data_index = 0

def finalize(self) -> None:
Expand All @@ -72,7 +76,7 @@ def finalize(self) -> None:
for k, v in self._cache_dict.items():
f.write(k)
for result in v.flatten():
f.write("," + str(result))
f.write(self.delimiter + str(result))
f.write("\n")
# clear cache content after writing
self.reset_cache()
Expand Down
8 changes: 7 additions & 1 deletion monai/handlers/classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self,
output_dir: str = "./",
filename: str = "predictions.csv",
delimiter: str = ",",
overwrite: bool = True,
batch_transform: Callable = lambda x: x,
output_transform: Callable = lambda x: x,
Expand All @@ -50,6 +51,8 @@ def __init__(
Args:
output_dir: if `saver=None`, output CSV file directory.
filename: if `saver=None`, name of the saved CSV file name.
delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
overwrite: if `saver=None`, whether to overwriting existing file content, if True,
will clear the file before saving. otherwise, will append new content to the file.
batch_transform: a callable that is used to extract the `meta_data` dictionary of
Expand All @@ -74,6 +77,7 @@ def __init__(
self.save_rank = save_rank
self.output_dir = output_dir
self.filename = filename
self.delimiter = delimiter
self.overwrite = overwrite
self.batch_transform = batch_transform
self.output_transform = output_transform
Expand Down Expand Up @@ -153,6 +157,8 @@ def _finalize(self, _engine: Engine) -> None:

# save to CSV file only in the expected rank
if idist.get_rank() == self.save_rank:
saver = self.saver or CSVSaver(self.output_dir, self.filename, self.overwrite)
saver = self.saver or CSVSaver(
output_dir=self.output_dir, filename=self.filename, overwrite=self.overwrite, delimiter=self.delimiter
)
saver.save_batch(outputs, meta_dict)
saver.finalize()
5 changes: 3 additions & 2 deletions monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ class mean median max 5percentile 95percentile notnans
mean 6.2500 6.2500 7.0000 5.5750 6.9250 2.0000
save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.
delimiter: the delimiter character in CSV file, default to "\t".
delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
"""
Expand All @@ -83,7 +84,7 @@ def __init__(
batch_transform: Callable = lambda x: x,
summary_ops: Optional[Union[str, Sequence[str]]] = None,
save_rank: int = 0,
delimiter: str = "\t",
delimiter: str = ",",
output_type: str = "csv",
) -> None:
self.save_dir = save_dir
Expand Down
5 changes: 3 additions & 2 deletions monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def write_metrics_reports(
metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]],
metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]],
summary_ops: Optional[Union[str, Sequence[str]]],
deli: str = "\t",
deli: str = ",",
output_type: str = "csv",
):
"""
Expand Down Expand Up @@ -88,7 +88,8 @@ class mean median max 5percentile 95percentile notnans
class1 6.0000 6.0000 6.0000 6.0000 6.0000 1.0000
mean 6.2500 6.2500 7.0000 5.5750 6.9250 2.0000
deli: the delimiter character in the file, default to "\t".
deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
"""
Expand Down
7 changes: 6 additions & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def __init__(
saver: Optional[CSVSaver] = None,
output_dir: PathLike = "./",
filename: str = "predictions.csv",
delimiter: str = ",",
overwrite: bool = True,
flush: bool = True,
allow_missing_keys: bool = False,
Expand All @@ -696,6 +697,8 @@ def __init__(
the saver must provide `save(data, meta_data)` and `finalize()` APIs.
output_dir: if `saver=None`, specify the directory to save the CSV file.
filename: if `saver=None`, specify the name of the saved CSV file.
delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
overwrite: if `saver=None`, indicate whether to overwriting existing CSV file content, if True,
will clear the file before saving. otherwise, will append new content to the CSV file.
flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately
Expand All @@ -707,7 +710,9 @@ def __init__(
super().__init__(keys, allow_missing_keys)
if len(self.keys) != 1:
raise ValueError("only 1 key is allowed when saving the classification result.")
self.saver = saver or CSVSaver(output_dir, filename, overwrite, flush)
self.saver = saver or CSVSaver(
output_dir=output_dir, filename=filename, overwrite=overwrite, flush=flush, delimiter=delimiter
)
self.flush = flush
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
class TestCSVSaver(unittest.TestCase):
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:
saver = CSVSaver(output_dir=tempdir, filename="predictions.csv")
saver = CSVSaver(output_dir=tempdir, filename="predictions.csv", delimiter="\t")
meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]}
saver.save_batch(torch.zeros(8), meta_data)
saver.finalize()
filepath = os.path.join(tempdir, "predictions.csv")
self.assertTrue(os.path.exists(filepath))
with open(filepath) as f:
reader = csv.reader(f)
reader = csv.reader(f, delimiter="\t")
i = 0
for row in reader:
self.assertEqual(row[0], "testfile" + str(i))
Expand Down
6 changes: 3 additions & 3 deletions tests/test_handler_classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def _train_func(engine, batch):
engine = Engine(_train_func)

# set up testing handler
saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv")
ClassificationSaver(output_dir=tempdir, filename="predictions1.csv").attach(engine)
saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv", delimiter="\t")
ClassificationSaver(output_dir=tempdir, filename="predictions1.csv", delimiter="\t").attach(engine)
ClassificationSaver(saver=saver).attach(engine)

data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}]
Expand All @@ -46,7 +46,7 @@ def _test_file(filename):
filepath = os.path.join(tempdir, filename)
self.assertTrue(os.path.exists(filepath))
with open(filepath) as f:
reader = csv.reader(f)
reader = csv.reader(f, delimiter="\t")
i = 0
for row in reader:
self.assertEqual(row[0], "testfile" + str(i))
Expand Down
1 change: 1 addition & 0 deletions tests/test_handler_metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_content(self):
metric_details=["metric3", "metric4"],
batch_transform=lambda x: x["image_meta_dict"],
summary_ops=["mean", "median", "max", "5percentile", "95percentile", "notnans"],
delimiter="\t",
)
# set up engine
data = [
Expand Down
1 change: 1 addition & 0 deletions tests/test_handler_metrics_saver_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _run(self, tempdir):
metric_details=["metric3", "metric4"],
batch_transform=lambda x: x["image_meta_dict"],
summary_ops="*",
delimiter="\t",
)

def _val_func(engine, batch):
Expand Down
8 changes: 6 additions & 2 deletions tests/test_save_classificationd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def test_saved_content(self):
},
]

saver = CSVSaver(output_dir=Path(tempdir), filename="predictions2.csv", overwrite=False, flush=False)
saver = CSVSaver(
output_dir=Path(tempdir), filename="predictions2.csv", overwrite=False, flush=False, delimiter="\t"
)
# set up test transforms
post_trans = Compose(
[
Expand All @@ -52,6 +54,7 @@ def test_saved_content(self):
meta_keys=None,
output_dir=Path(tempdir),
filename="predictions1.csv",
delimiter="\t",
overwrite=True,
),
# 2rd saver only saves data into the cache, manually finalize later
Expand All @@ -75,6 +78,7 @@ def test_saved_content(self):
meta_keys="image_meta_dict", # specify meta key, so no need to copy anymore
output_dir=tempdir,
filename="predictions1.csv",
delimiter="\t",
overwrite=False,
)
d = decollate_batch(data[2])
Expand All @@ -85,7 +89,7 @@ def _test_file(filename, count):
filepath = os.path.join(tempdir, filename)
self.assertTrue(os.path.exists(filepath))
with open(filepath) as f:
reader = csv.reader(f)
reader = csv.reader(f, delimiter="\t")
i = 0
for row in reader:
self.assertEqual(row[0], "testfile" + str(i))
Expand Down

0 comments on commit abe24cf

Please sign in to comment.