Skip to content

Commit

Permalink
Fix LighterFileWriter (#130)
Browse files Browse the repository at this point in the history
* Bump version

* Fix LighterFileWriter + add more test cases

---------

Co-authored-by: GitHub Action <action@github.com>
  • Loading branch information
surajpaib and actions-user authored Jul 29, 2024
1 parent 6081ebb commit c88d891
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 10 deletions.
12 changes: 7 additions & 5 deletions lighter/callbacks/writer/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ class LighterFileWriter(LighterBaseWriter):
for a more permanent solution, it can be added to the `self.writers` dictionary.
Args:
directory (Union[str, Path]): Directory where the files should be written.
path (Union[str, Path]): Directory where the files should be written.
writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function.
Available writers: "tensor", "image", "video", "itk_nrrd", "itk_seg_nrrd", "itk_nifti".
A custom writer function must take two arguments: `path` and `tensor`, and write the tensor to the specified path.
`tensor` is a single tensor without the batch dimension.
"""

def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None:
super().__init__(directory, writer)
def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
super().__init__(path, writer)

@property
def writers(self) -> Dict[str, Callable]:
Expand All @@ -49,9 +49,11 @@ def write(self, tensor: torch.Tensor, id: Union[int, str]) -> None:
tensor (Tensor): Tensor, without the batch dimension, to be written.
id (Union[int, str]): Identifier, used for file-naming.
"""
if not self.path.is_dir():
raise RuntimeError(f"LighterFileWriter expects a directory path, got {self.path}")

# Determine the path for the file based on prediction count. The suffix must be added by the writer function.
path = self.directory / str(id)
path.parent.mkdir(exist_ok=True, parents=True)
path = self.path / str(id)
# Write the tensor to the file.
self.writer(path, tensor)

Expand Down
3 changes: 2 additions & 1 deletion lighter/utils/dynamic_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def import_module_from_path(module_name: str, module_path: str) -> None:
# Based on https://stackoverflow.com/a/41595552.

if module_name in sys.modules:
raise ValueError(f"{module_name} has already been imported as module.")
logger.warning(f"{module_name} has already been imported as module.")
return

module_path = Path(module_path).resolve() / "__init__.py"
if not module_path.is_file():
Expand Down
2 changes: 1 addition & 1 deletion projects/cifar10/experiments/monai_bundle_prototype.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ trainer:
logger: False
callbacks:
- _target_: lighter.callbacks.LighterFileWriter
directory: '$f"{@project}/predictions"'
path: '$f"{@project}/predictions"'
writer: tensor

system:
Expand Down
14 changes: 12 additions & 2 deletions tests/integration/test_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@
"fit",
# Config fiile
"./projects/cifar10/experiments/monai_bundle_prototype.yaml",
)
),
( # Method name
"test",
# Config fiile
"./projects/cifar10/experiments/monai_bundle_prototype.yaml",
),
( # Method name
"predict",
# Config fiile
"./projects/cifar10/experiments/monai_bundle_prototype.yaml",
),
],
)
@pytest.mark.slow
def test_trainer_method(method_name: str, config: str):
""" """
kwargs = {"config": [config, test_overrides]}

kwargs = {"config": [config, test_overrides]}
func_return = run(method_name, **kwargs)
assert func_return is None
1 change: 0 additions & 1 deletion tests/integration/test_overrides.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ trainer#fast_dev_run: True
trainer#accelerator: cpu
system#batch_size: 16
system#num_workers: 2
trainer#callbacks: null

0 comments on commit c88d891

Please sign in to comment.