Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 845 add dataconfig with cache #93

Merged
merged 53 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
0938d8e
remove title from plits
Louis-Dupont Jun 6, 2023
32231e1
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 6, 2023
e20666d
rename
Louis-Dupont Jun 7, 2023
7f285c8
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 7, 2023
bebdd0e
wip
Louis-Dupont Jun 7, 2023
4834035
update
Louis-Dupont Jun 7, 2023
1aab0c9
update
Louis-Dupont Jun 7, 2023
dc2f1e4
improve doc
Louis-Dupont Jun 7, 2023
7d1f0a1
update
Louis-Dupont Jun 7, 2023
9cbe772
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 7, 2023
726d4de
fix
Louis-Dupont Jun 7, 2023
01f9c3b
add change from master
Louis-Dupont Jun 8, 2023
5d2d9e4
Big refacto simplification
Louis-Dupont Jun 8, 2023
b2e61ea
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 11, 2023
62e629f
rename data_config
Louis-Dupont Jun 11, 2023
010c278
working version
Louis-Dupont Jun 11, 2023
c85c504
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 12, 2023
e969424
remove unused code
Louis-Dupont Jun 12, 2023
8d6716e
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 13, 2023
28913b0
Merge branch 'master' into feature/SG-920-remove_title_from_plots
Louis-Dupont Jun 13, 2023
d5a2ce9
wip
Louis-Dupont Jun 13, 2023
53457ff
wip
Louis-Dupont Jun 14, 2023
5f0623f
Merge branch 'master' into feature/SG-845-add-dataconfig-with-cache
Louis-Dupont Jun 14, 2023
b41c890
wip
Louis-Dupont Jun 14, 2023
f613a45
clean and introduce resolver
Louis-Dupont Jun 14, 2023
d3cbb77
working version
Louis-Dupont Jun 14, 2023
4e51fe8
moving to folder
Louis-Dupont Jun 15, 2023
21505c9
add docstring
Louis-Dupont Jun 15, 2023
34ca038
add some doc
Louis-Dupont Jun 15, 2023
de14427
add docstring
Louis-Dupont Jun 15, 2023
ffabada
Merge branch 'master' into feature/SG-845-add-dataconfig-with-cache
Louis-Dupont Jun 18, 2023
0fe7365
fix merge
Louis-Dupont Jun 18, 2023
d8536cd
fix
Louis-Dupont Jun 18, 2023
5075005
update, simplify and add use_cache option
Louis-Dupont Jun 19, 2023
24293a3
Merge branch 'master' into feature/SG-845-add-dataconfig-with-cache
Louis-Dupont Jun 19, 2023
df249c1
show message when error
Louis-Dupont Jun 19, 2023
e17b13d
Merge branch 'master' into feature/SG-845-add-dataconfig-with-cache
Louis-Dupont Jun 20, 2023
bb6c04b
fix
Louis-Dupont Jun 20, 2023
9068dfb
fix
Louis-Dupont Jun 20, 2023
66fdca3
improve exception
Louis-Dupont Jun 20, 2023
9ce71f2
improve doc
Louis-Dupont Jun 20, 2023
d659024
refine
Louis-Dupont Jun 20, 2023
4520128
rename typing
Louis-Dupont Jun 21, 2023
f1d4112
improve tensor extractor prints
Louis-Dupont Jun 21, 2023
24246f3
move data_config to inside our code
Louis-Dupont Jun 21, 2023
4f27a96
improve exception explanation
Louis-Dupont Jun 21, 2023
eb701de
remove exception when unsupported object
Louis-Dupont Jun 21, 2023
fc8b5ff
wrap logic inside _extract_images
Louis-Dupont Jun 21, 2023
5c01ff6
add colors for all questions
Louis-Dupont Jun 21, 2023
4fe9905
fix visual bug
Louis-Dupont Jun 21, 2023
5d51272
undo unwanted change
Louis-Dupont Jun 21, 2023
3cf019f
add new case of DataLookupError
Louis-Dupont Jun 21, 2023
c9aa4e8
minor update
Louis-Dupont Jun 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/example_detection_super_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
val_data=val_loader,
class_names=train_loader.dataset.classes,
batches_early_stop=20,
use_cache=True, # With this we will be asked about the dataset information only once
)

analyzer.run()
1 change: 1 addition & 0 deletions examples/segmentation_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
# Optionals
images_extractor=None,
labels_extractor=None,
use_cache=True,
threshold_soft_labels=0.5,
batches_early_stop=75,
)
Expand Down
122 changes: 73 additions & 49 deletions src/data_gradients/batch_processors/adapters/dataset_adapter.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,92 @@
from typing import Optional, Callable, Union, Tuple, List, Mapping, Any
from typing import Callable, Union, Tuple, List, Mapping

import PIL
import numpy as np
import torch
from torchvision.transforms import transforms

from data_gradients.batch_processors.adapters.tensor_extractor import TensorExtractor
from data_gradients.batch_processors.adapters.tensor_extractor import get_tensor_extractor_options
from data_gradients.config.data.data_config import DataConfig
from data_gradients.config.data.questions import Question, text_to_yellow

SupportedData = Union[Tuple, List, Mapping, Tuple, List]


class DatasetAdapter:
"""Class responsible to convert raw batch (coming from dataloader) into a batch of image and a batch of labels."""

def __init__(self, images_extractor: Optional[Callable] = None, labels_extractor: Optional[Callable] = None):
"""
:param images_extractor: (Optional) function that takes the dataloader output and extract the images.
If None, the user will need to input it manually in a following prompt.
:param labels_extractor: (Optional) function that takes the dataloader output and extract the labels.
If None, the user will need to input it manually in a following prompt.
"""
self._tensor_extractor = {0: images_extractor, 1: labels_extractor}
def __init__(self, data_config: DataConfig):
    """
    :param data_config: Configuration object holding the `images_extractor` and `labels_extractor` used to split a raw batch.
    """
    self.data_config = data_config

def extract(self, objs: Union[Tuple, List, Mapping]) -> Tuple[torch.Tensor, torch.Tensor]:
def extract(self, data: SupportedData) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert raw batch (coming from dataloader) into a batch of image and a batch of labels.

:param objs: Raw batch (coming from dataloader without any modification).
:param data: Raw batch (coming from dataloader without any modification).
:return:
- images: Batch of images
- labels: Batch of labels
"""
if isinstance(objs, (Tuple, List)) and len(objs) == 2:
images = objs[0] if isinstance(objs[0], torch.Tensor) else self._to_tensor(objs[0], tuple_idx=0)
labels = objs[1] if isinstance(objs[1], torch.Tensor) else self._to_tensor(objs[1], tuple_idx=1)
elif isinstance(objs, (Mapping, Tuple, List)):
images = self._extract_tensor_from_container(objs, 0)
labels = self._extract_tensor_from_container(objs, 1)
else:
raise NotImplementedError(f"Got object {type(objs)} from Iterator - supporting dict, tuples and lists Only!")
return images, labels

def _to_tensor(self, objs: Union[np.ndarray, PIL.Image.Image, Mapping], tuple_idx: int) -> torch.Tensor:
if isinstance(objs, np.ndarray):
return torch.from_numpy(objs)
elif isinstance(objs, PIL.Image.Image):
return transforms.ToTensor()(objs)
images = self._extract_images(data)
labels = self._extract_labels(data)
return self._to_torch(images), self._to_torch(labels)

def _extract_images(self, data: SupportedData) -> torch.Tensor:
    """Extract the images from a raw batch.

    :param data: Raw batch (coming from dataloader without any modification).
    :return:     Output of the images extractor resolved for `data`, applied to `data`.
    """
    return self._get_images_extractor(data)(data)

def _extract_labels(self, data: SupportedData) -> torch.Tensor:
    """Extract the labels from a raw batch.

    :param data: Raw batch (coming from dataloader without any modification).
    :return:     Output of the labels extractor resolved for `data`, applied to `data`.
    """
    return self._get_labels_extractor(data)(data)

def _get_images_extractor(self, data: SupportedData) -> Callable[[SupportedData], torch.Tensor]:
    """Resolve the callable that maps a raw batch to its images.

    Resolution order:
        1. If `data_config.images_extractor` is already set, the config's extractor is returned.
        2. If `data` is a tuple/list of exactly 2 elements whose first element is a tensor, a numpy array
           or a PIL image, it is assumed to be `(image, label)` and `"[0]"` is saved on the config.
        3. If `data` is any other tuple, list or mapping, the user is asked how to map data -> image.

    :param data: Raw batch (coming from dataloader without any modification).
    :return:     Callable taking the raw batch and returning the images.
    :raises NotImplementedError: If `data` is neither a tuple, a list nor a mapping.
    """
    config = self.data_config
    if config.images_extractor is not None:
        return config.get_images_extractor()

    # Heuristic: a pair is expected to be (image, label), in this order.
    tensor_like = (torch.Tensor, np.ndarray, PIL.Image.Image)
    is_pair = isinstance(data, (Tuple, List)) and len(data) == 2
    if is_pair and isinstance(data[0], tensor_like):
        config.images_extractor = "[0]"  # Saved for later use
        return config.get_images_extractor()  # This will return a callable

    if not isinstance(data, (Tuple, List, Mapping)):
        raise NotImplementedError(
            f"Got object {type(data)} from Data Iterator which is not supported!\n"
            f"Please implement a custom `images_extractor` for your dataset. "
            f"You can find more detail about this in our documentation: https://github.com/Deci-AI/data-gradients"
        )

    # No shortcut applied: ask the user how to map data -> image.
    hint, choices = get_tensor_extractor_options(data)
    question = Question(question=f"Which tensor represents your {text_to_yellow('Image(s)')} ?", options=choices)
    return config.get_images_extractor(question=question, hint=hint)

def _get_labels_extractor(self, data: SupportedData) -> Callable[[SupportedData], torch.Tensor]:
    """Resolve the callable that maps a raw batch to its labels.

    Resolution order:
        1. If `data_config.labels_extractor` is already set, the config's extractor is returned.
        2. If `data` is a tuple/list of exactly 2 elements whose second element is a tensor, a numpy array
           or a PIL image, it is assumed to be `(image, label)` and `"[1]"` is saved on the config.
        3. If `data` is any other tuple, list or mapping, the user is asked how to map data -> labels.

    :param data: Raw batch (coming from dataloader without any modification).
    :return:     Callable taking the raw batch and returning the labels.
    :raises NotImplementedError: If `data` is neither a tuple, a list nor a mapping.
    """
    config = self.data_config
    if config.labels_extractor is not None:
        return config.get_labels_extractor()

    # Heuristic: a pair is expected to be (image, label), in this order.
    tensor_like = (torch.Tensor, np.ndarray, PIL.Image.Image)
    is_pair = isinstance(data, (Tuple, List)) and len(data) == 2
    if is_pair and isinstance(data[1], tensor_like):
        config.labels_extractor = "[1]"  # Saved for later use
        return config.get_labels_extractor()  # This will return a callable

    if not isinstance(data, (Tuple, List, Mapping)):
        raise NotImplementedError(
            f"Got object {type(data)} from Data Iterator which is not supported!\n"
            f"Please implement a custom `labels_extractor` for your dataset. "
            f"You can find more detail about this in our documentation: https://github.com/Deci-AI/data-gradients"
        )

    # No shortcut applied: ask the user how to map data -> labels.
    hint, choices = get_tensor_extractor_options(data)
    question = Question(question=f"Which tensor represents your {text_to_yellow('Label(s)')} ?", options=choices)
    return config.get_labels_extractor(question=question, hint=hint)

@staticmethod
def _to_torch(tensor: Union[np.ndarray, PIL.Image.Image, torch.Tensor]) -> torch.Tensor:
if isinstance(tensor, np.ndarray):
return torch.from_numpy(tensor)
elif isinstance(tensor, PIL.Image.Image):
return transforms.ToTensor()(tensor)
else:
return self._extract_tensor_from_container(objs=objs, tuple_idx=tuple_idx)

def _extract_tensor_from_container(self, objs: Any, tuple_idx: int) -> torch.Tensor:
mapping_fn = self._get_tensor_extractor(tuple_idx=tuple_idx, objs=objs)
return mapping_fn(objs)

def _get_tensor_extractor(self, objs: Any, tuple_idx: int) -> Union[Callable, TensorExtractor]:
if self._tensor_extractor[tuple_idx] is None:
self._tensor_extractor[tuple_idx] = TensorExtractor(objs=objs, name="image(s)" if (tuple_idx == 0) else "label(s)")
return self._tensor_extractor[tuple_idx]

@property
def images_route(self) -> List[str]:
"""Represent the path (route) to extract the images from the raw batch (coming from dataloader)."""
tensor_finder = self._tensor_extractor[0]
return tensor_finder.path_to_tensor if isinstance(tensor_finder, TensorExtractor) else []

@property
def labels_route(self) -> List[str]:
"""Represent the path (route) to extract the labels from the raw batch (coming from dataloader)."""
tensor_finder = self._tensor_extractor[1]
return tensor_finder.path_to_tensor if isinstance(tensor_finder, TensorExtractor) else []
return tensor
Loading