Feature/sg 000 add datasetadapter features #187

Merged
merged 24 commits from feature/SG-000-add_datasetadapter_features into master on Sep 26, 2023

Commits (24)
f584469
wip
Louis-Dupont Sep 12, 2023
b7f8021
wip
Louis-Dupont Sep 12, 2023
6b995e3
Merge branch 'master' into feature/SG-000-add_datasetadapter_features
Louis-Dupont Sep 12, 2023
498e375
Merge branch 'master' into feature/SG-000-add_datasetadapter_features
Louis-Dupont Sep 14, 2023
0e8174a
refactor how we handle the processing
Louis-Dupont Sep 14, 2023
c9b4cbc
remove useless code
Louis-Dupont Sep 14, 2023
ef543e7
cleanup
Louis-Dupont Sep 14, 2023
b749f9b
rename folder to sample_preprocessor
Louis-Dupont Sep 14, 2023
c37f101
rename modules
Louis-Dupont Sep 14, 2023
4e2a7ce
Merge branch 'master' into feature/SG-000-add_datasetadapter_features
Louis-Dupont Sep 14, 2023
fb906f8
Merge branch 'master' into feature/SG-000-add_datasetadapter_features
Louis-Dupont Sep 14, 2023
090d86d
Merge branch 'master' into feature/SG-000-add_datasetadapter_features
Louis-Dupont Sep 18, 2023
97768de
fix merge
Louis-Dupont Sep 18, 2023
bb1f147
minor fix
Louis-Dupont Sep 18, 2023
d60ea6b
group datasetadapter with dataconfig
Louis-Dupont Sep 18, 2023
9dfea08
wip
Louis-Dupont Sep 18, 2023
4fec8bc
add some tests
Louis-Dupont Sep 18, 2023
824b4b9
add
Louis-Dupont Sep 19, 2023
19b24a4
fix to work on local path
Louis-Dupont Sep 19, 2023
4d1b8e4
add is_batch to params
Louis-Dupont Sep 19, 2023
e250ea3
Merge branch 'master' into feature/SG-000-add_datasetadapter_features
Louis-Dupont Sep 20, 2023
67adbbf
config to data_config
Louis-Dupont Sep 20, 2023
5cf015e
undo unwanted changes
Louis-Dupont Sep 20, 2023
e4bdf95
update test
Louis-Dupont Sep 20, 2023
2 changes: 1 addition & 1 deletion src/data_gradients/config/utils.py
@@ -6,7 +6,7 @@
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

-from data_gradients.config.data.typing import FeatureExtractorsType
+from data_gradients.dataset_adapters.config.typing import FeatureExtractorsType
from data_gradients.feature_extractors import AbstractFeatureExtractor
from data_gradients.common.factories import FeatureExtractorsFactory, ListFactory

32 changes: 16 additions & 16 deletions src/data_gradients/dataset_adapters/base_adapter.py
@@ -1,34 +1,31 @@
 from abc import ABC
-from typing import List, Iterable, Sized, Tuple
+from typing import List, Tuple

 import torch

-from data_gradients.config.data.typing import SupportedDataType
-from data_gradients.config.data.data_config import DataConfig
+from data_gradients.dataset_adapters.config.data_config import DataConfig

 from data_gradients.dataset_adapters.formatters.base import BatchFormatter
 from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
+from data_gradients.dataset_adapters.config.typing import SupportedDataType


 class BaseDatasetAdapter(ABC):
     """Wraps a dataset and applies transformations to data points.
     It acts as a base class for specific dataset adapters that cater to specific data structures.

-    :param data_iterable: Iterable object that yields data points from the dataset.
     :param formatter: Instance of BatchFormatter that is used to validate and format the batches of images and labels
                       into the appropriate format for a given task.
     :param data_config: Instance of DataConfig class that manages dataset/dataloader configurations.
     """

     def __init__(
         self,
-        data_iterable: Iterable[SupportedDataType],
         dataset_output_mapper: DatasetOutputMapper,
         formatter: BatchFormatter,
         data_config: DataConfig,
         class_names: List[str],
     ):
-        self.data_iterable = data_iterable
         self.data_config = data_config

         self.dataset_output_mapper = dataset_output_mapper
@@ -53,14 +50,17 @@ def resolve_class_names_to_use(class_names: List[str], class_names_to_use: List[
             raise RuntimeError(f"You defined `class_names_to_use` with classes that are not listed in `class_names`: {invalid_class_names_to_use}")
         return class_names_to_use or class_names

-    def __len__(self) -> int:
-        """Length of the dataset if available. Otherwise, None."""
-        return len(self.data_iterable) if isinstance(self.data_iterable, Sized) else None
+    def adapt(self, data: SupportedDataType) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Adapt input data (a batch or a sample) into a standard format.

-    def __iter__(self) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
-        """Iterate over the dataset and return a batch of images and labels."""
-        for data in self.data_iterable:
-            # data can be a batch or a sample
-            images, labels = self.dataset_output_mapper.extract(data)
-            images, labels = self.formatter.format(images, labels)
-            yield images, labels
+        :param data: Input data to be adapted.
+            - Can represent a batch or a sample.
+            - Can be structured in a wide range of formats (list, dict, ...).
+            - Can be laid out in a wide range of formats (image: HWC, CHW, ...; label: label_cxcywh, xyxy_label, ...).
+        :return: Tuple of images and labels.
+            - Images are formatted to (BS, H, W, C), with BS = 1 if the original data is a single sample.
+            - Labels are formatted to a standard format that depends on the task.
+        """
+        images, labels = self.dataset_output_mapper.extract(data)
+        images, labels = self.formatter.format(images, labels)
+        return images, labels
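
With this change, BaseDatasetAdapter no longer owns the iteration loop: __len__ and __iter__ are gone, and callers iterate their own data source and pass each item through adapt. A minimal caller-side sketch; `adapter` (any concrete adapter instance) and `data_iterable` (any iterable of samples or batches) are placeholder names, not part of this diff:

def iterate_adapted(adapter, data_iterable):
    # Replaces the old BaseDatasetAdapter.__iter__ loop on the caller side.
    for data in data_iterable:
        images, labels = adapter.adapt(data)  # images: (BS, H, W, C); labels: task-dependent
        yield images, labels
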
13 changes: 5 additions & 8 deletions src/data_gradients/dataset_adapters/classification_adapter.py
@@ -1,17 +1,16 @@
-from typing import List, Optional, Iterable, Callable
+from typing import List, Optional, Callable
 import torch

-from data_gradients.config.data.typing import SupportedDataType
+from data_gradients.dataset_adapters.config.typing import SupportedDataType
 from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
 from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
-from data_gradients.config.data.data_config import ClassificationDataConfig
+from data_gradients.dataset_adapters.config.data_config import ClassificationDataConfig
 from data_gradients.dataset_adapters.formatters.classification import ClassificationBatchFormatter


 class ClassificationDatasetAdapter(BaseDatasetAdapter):
     """Wraps a classification dataset so that it returns standardized tensors.

-    :param data_iterable: Iterable object that yields data points from the dataset.
     :param cache_path: The filename of the cache file.
     :param n_classes: The number of classes.
     :param class_names: List of class names.
@@ -24,18 +23,16 @@ class ClassificationDatasetAdapter(BaseDatasetAdapter):

     def __init__(
         self,
-        data_iterable: Iterable[SupportedDataType],
         cache_path: Optional[str] = None,
         n_classes: Optional[int] = None,
         class_names: Optional[List[str]] = None,
         class_names_to_use: Optional[List[str]] = None,
         images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
         labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
+        is_batch: Optional[bool] = None,
         n_image_channels: int = 3,
         data_config: Optional[ClassificationDataConfig] = None,
     ):
-        self.data_iterable = data_iterable
-
         class_names = self.resolve_class_names(class_names=class_names, n_classes=n_classes)
         class_names_to_use = self.resolve_class_names_to_use(class_names=class_names, class_names_to_use=class_names_to_use)

@@ -44,6 +41,7 @@ def __init__(
             cache_path=cache_path,
             images_extractor=images_extractor,
             labels_extractor=labels_extractor,
+            is_batch=is_batch,
         )

         dataset_output_mapper = DatasetOutputMapper(data_config=data_config)
@@ -54,7 +52,6 @@ def __init__(
             n_image_channels=n_image_channels,
         )
         super().__init__(
-            data_iterable=data_iterable,
             dataset_output_mapper=dataset_output_mapper,
             formatter=formatter,
             data_config=data_config,
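
For reference, a usage sketch under the new signature; the class names, tensors, and shapes below are hypothetical, and depending on the data structure the adapter may still ask interactive questions for anything it cannot infer:

import torch

from data_gradients.dataset_adapters.classification_adapter import ClassificationDatasetAdapter

# The adapter is now built without a data_iterable; is_batch declares up front
# whether adapt() will receive whole batches or single samples.
adapter = ClassificationDatasetAdapter(class_names=["cat", "dog"], is_batch=True)

images = torch.rand(4, 3, 32, 32)    # hypothetical batch of 4 RGB images (BCHW)
labels = torch.tensor([0, 1, 1, 0])  # hypothetical integer class ids
adapted_images, adapted_labels = adapter.adapt((images, labels))
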
@@ -4,7 +4,7 @@
import torch

from data_gradients.dataset_adapters.output_mapper.tensor_extractor import NestedDataLookup
-from data_gradients.config.data.typing import SupportedDataType
+from data_gradients.dataset_adapters.config.typing import SupportedDataType
from data_gradients.utils.detection import XYXYConverter

# This is used as a prefix to recognize parameters that are not cachable.
@@ -8,9 +8,9 @@
from typing import Dict, Optional, Callable, Union

import data_gradients
-from data_gradients.config.data.questions import Question, ask_question, text_to_yellow
-from data_gradients.config.data.caching_utils import TensorExtractorResolver, XYXYConverterResolver
-from data_gradients.config.data.typing import SupportedDataType, JSONDict
+from data_gradients.dataset_adapters.config.questions import Question, ask_question, text_to_yellow
+from data_gradients.dataset_adapters.config.caching_utils import TensorExtractorResolver, XYXYConverterResolver
+from data_gradients.dataset_adapters.config.typing import SupportedDataType, JSONDict
from data_gradients.utils.detection import XYXYConverter
from data_gradients.utils.utils import safe_json_load, write_json

@@ -109,6 +109,11 @@ def to_json(self) -> JSONDict:
         }
         return json_dict

+    @property
+    def is_completely_initialized(self) -> bool:
+        """Check whether all the attributes are set."""
+        return all(v is not None for v in self.to_json().values())
+
     def _fill_missing_params_with_cache(self, path: str):
         """Load an instance of DataConfig directly from a cache file.
         :param path: Full path of the cache file. This should end with ".json" extension.
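
The new is_completely_initialized property is an "every cachable field is set" check over the JSON view of the config. A standalone sketch of the same logic on a toy config (ToyDataConfig and its two fields are hypothetical, not part of this diff):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyDataConfig:
    images_extractor: Optional[str] = None
    labels_extractor: Optional[str] = None

    def to_json(self) -> dict:
        return {"images_extractor": self.images_extractor, "labels_extractor": self.labels_extractor}

    @property
    def is_completely_initialized(self) -> bool:
        # True only once every cachable field has been resolved.
        return all(v is not None for v in self.to_json().values())

assert not ToyDataConfig().is_completely_initialized
assert ToyDataConfig("[0]", "[1]").is_completely_initialized
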
13 changes: 5 additions & 8 deletions src/data_gradients/dataset_adapters/detection_adapter.py
@@ -1,18 +1,17 @@
-from typing import List, Optional, Iterable, Callable
+from typing import List, Optional, Callable

 import torch

-from data_gradients.config.data.typing import SupportedDataType
+from data_gradients.dataset_adapters.config.typing import SupportedDataType
 from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
 from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
 from data_gradients.dataset_adapters.formatters.detection import DetectionBatchFormatter
-from data_gradients.config.data.data_config import DetectionDataConfig
+from data_gradients.dataset_adapters.config.data_config import DetectionDataConfig


 class DetectionDatasetAdapter(BaseDatasetAdapter):
     """Wraps a detection dataset so that it returns standardized tensors.

-    :param data_iterable: Iterable object that yields data points from the dataset.
     :param cache_path: The filename of the cache file.
     :param n_classes: The number of classes.
     :param class_names: List of class names.
@@ -27,20 +26,18 @@ class DetectionDatasetAdapter(BaseDatasetAdapter):

     def __init__(
         self,
-        data_iterable: Iterable[SupportedDataType],
         cache_path: Optional[str] = None,
         n_classes: Optional[int] = None,
         class_names: Optional[List[str]] = None,
         class_names_to_use: Optional[List[str]] = None,
         images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
         labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
+        is_batch: Optional[bool] = None,
         is_label_first: Optional[bool] = None,
         bbox_format: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
         n_image_channels: int = 3,
         data_config: Optional[DetectionDataConfig] = None,
     ):
-        self.data_iterable = data_iterable
-
         class_names = self.resolve_class_names(class_names=class_names, n_classes=n_classes)
         class_names_to_use = self.resolve_class_names_to_use(class_names=class_names, class_names_to_use=class_names_to_use)

@@ -49,6 +46,7 @@ def __init__(
             cache_path=cache_path,
             images_extractor=images_extractor,
             labels_extractor=labels_extractor,
+            is_batch=is_batch,
             is_label_first=is_label_first,
             xyxy_converter=bbox_format,
         )

@@ -61,7 +59,6 @@ def __init__(
             n_image_channels=n_image_channels,
         )
         super().__init__(
-            data_iterable=data_iterable,
             dataset_output_mapper=dataset_output_mapper,
             formatter=formatter,
             data_config=data_config,
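
Same idea for detection, with the detection-specific knobs; everything below is a hypothetical sketch, assuming targets come as (N, 5) tensors laid out label-first with xyxy boxes (so the bbox_format callable can be an identity):

import torch

from data_gradients.dataset_adapters.detection_adapter import DetectionDatasetAdapter

adapter = DetectionDatasetAdapter(
    class_names=["person", "car"],
    is_batch=False,           # adapt() will receive one sample at a time
    is_label_first=True,      # targets are [class_id, x1, y1, x2, y2]
    bbox_format=lambda t: t,  # boxes already xyxy; convert here otherwise
)

image = torch.rand(3, 64, 64)
targets = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]])
images, labels = adapter.adapt((image, targets))  # images: (1, H, W, C)
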
@@ -7,7 +7,7 @@
from data_gradients.dataset_adapters.formatters.base import BatchFormatter
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError, check_images_shape
from data_gradients.dataset_adapters.formatters.utils import ensure_channel_first
-from data_gradients.config.data.data_config import ClassificationDataConfig
+from data_gradients.dataset_adapters.config.data_config import ClassificationDataConfig


class UnsupportedClassificationBatchFormatError(DatasetFormatError):
@@ -6,7 +6,7 @@
from data_gradients.dataset_adapters.utils import check_all_integers
from data_gradients.dataset_adapters.formatters.base import BatchFormatter
from data_gradients.dataset_adapters.formatters.utils import check_images_shape, ensure_channel_first, drop_nan
-from data_gradients.config.data.data_config import DetectionDataConfig
+from data_gradients.dataset_adapters.config.data_config import DetectionDataConfig
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError


@@ -5,7 +5,7 @@

from data_gradients.dataset_adapters.formatters.base import BatchFormatter
from data_gradients.dataset_adapters.utils import check_all_integers, to_one_hot
-from data_gradients.config.data.data_config import SegmentationDataConfig
+from data_gradients.dataset_adapters.config.data_config import SegmentationDataConfig
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError, check_images_shape, ensure_channel_first, drop_nan


@@ -6,8 +6,8 @@
from torchvision.transforms import transforms

from data_gradients.dataset_adapters.output_mapper.tensor_extractor import get_tensor_extractor_options
-from data_gradients.config.data.data_config import DataConfig
-from data_gradients.config.data.questions import Question, text_to_yellow
+from data_gradients.dataset_adapters.config.data_config import DataConfig
+from data_gradients.dataset_adapters.config.questions import Question, text_to_yellow

SupportedData = Union[Tuple, List, Mapping, Tuple, List]

11 changes: 5 additions & 6 deletions src/data_gradients/dataset_adapters/segmentation_adapter.py
@@ -1,19 +1,18 @@
-from typing import List, Optional, Iterable, Callable
+from typing import List, Optional, Callable

 import torch

-from data_gradients.config.data.typing import SupportedDataType
+from data_gradients.dataset_adapters.config.typing import SupportedDataType

 from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
 from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
 from data_gradients.dataset_adapters.formatters.segmentation import SegmentationBatchFormatter
-from data_gradients.config.data.data_config import SegmentationDataConfig
+from data_gradients.dataset_adapters.config.data_config import SegmentationDataConfig


 class SegmentationDatasetAdapter(BaseDatasetAdapter):
     """Wraps a segmentation dataset so that it returns standardized tensors.

-    :param data_iterable: Iterable object that yields data points from the dataset.
     :param cache_path: The filename of the cache file.
     :param n_classes: The number of classes.
     :param class_names: List of class names.
@@ -27,13 +26,13 @@ class SegmentationDatasetAdapter(BaseDatasetAdapter):

     def __init__(
         self,
-        data_iterable: Iterable[SupportedDataType],
         cache_path: Optional[str] = None,
         n_classes: Optional[int] = None,
         class_names: Optional[List[str]] = None,
         class_names_to_use: Optional[List[str]] = None,
         images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
         labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
+        is_batch: Optional[bool] = None,
         n_image_channels: int = 3,
         threshold_soft_labels: float = 0.5,
         data_config: Optional[SegmentationDataConfig] = None,

@@ -46,6 +45,7 @@ def __init__(
             cache_path=cache_path,
             images_extractor=images_extractor,
             labels_extractor=labels_extractor,
+            is_batch=is_batch,
         )

         dataset_output_mapper = DatasetOutputMapper(data_config=data_config)
@@ -57,7 +57,6 @@
             threshold_value=threshold_soft_labels,
         )
         super().__init__(
-            data_iterable=data_iterable,
             dataset_output_mapper=dataset_output_mapper,
             formatter=formatter,
             data_config=data_config,
@@ -5,7 +5,7 @@
from data_gradients.utils.data_classes import SegmentationSample
from data_gradients.visualize.seaborn_renderer import KDEPlotOptions
from data_gradients.feature_extractors.abstract_feature_extractor import AbstractFeatureExtractor
-from data_gradients.sample_iterables import contours
+from data_gradients.sample_preprocessor.utils import contours


@register_feature_extractor()
@@ -8,7 +8,7 @@
from data_gradients.utils.data_classes import SegmentationSample
from data_gradients.visualize.seaborn_renderer import KDEPlotOptions
from data_gradients.feature_extractors.abstract_feature_extractor import AbstractFeatureExtractor
-from data_gradients.sample_iterables import contours
+from data_gradients.sample_preprocessor.utils import contours


@register_feature_extractor()