Feature/dg 979 support classification #149

Merged
merged 30 commits into from
Jul 19, 2023
Changes from 19 commits
Commits
ce61298
Classification support (WIP)
BloodAxe Jul 10, 2023
e5c8e57
Added logging of the event if feature extractor failed
BloodAxe Jul 10, 2023
6ef4031
Merge branch 'feature/DG-000-add-logging' into feature/DG-979-support…
BloodAxe Jul 10, 2023
a115d3b
Fixing summary report for classification
BloodAxe Jul 10, 2023
7d47805
Merge branch 'master' into feature/DG-979-support-classification
BloodAxe Jul 10, 2023
d385a22
Remove default value batches_early_stop for ClassificationAnalysisMan…
BloodAxe Jul 10, 2023
03d458a
Remove default value batches_early_stop for ClassificationAnalysisMan…
BloodAxe Jul 10, 2023
c48151d
Remove default value batches_early_stop for ClassificationAnalysisMan…
BloodAxe Jul 10, 2023
c7dc25a
Merge branch 'master' into feature/DG-979-support-classification
shaydeci Jul 10, 2023
da94264
Support dataset
BloodAxe Jul 10, 2023
b1cf524
Merge remote-tracking branch 'origin/feature/DG-979-support-classific…
BloodAxe Jul 10, 2023
f70b6c3
Copy-paste bugfix
BloodAxe Jul 10, 2023
1bd963b
New feature extractor ClassificationClassDistributionVsArea
BloodAxe Jul 11, 2023
35256b8
Change x axis to use image size instead of image area
BloodAxe Jul 11, 2023
fcc8966
Added action points to description
BloodAxe Jul 11, 2023
26ba285
Added action points to description
BloodAxe Jul 11, 2023
8530cb2
Added action points to description
BloodAxe Jul 11, 2023
927dac6
Merge branch 'master' into feature/DG-979-support-classification
BloodAxe Jul 11, 2023
4e052aa
Fix PR
BloodAxe Jul 11, 2023
43c4f16
Added normalization to handle case when images were normalized with s…
BloodAxe Jul 12, 2023
ee806ef
Merge branch 'master' into feature/DG-979-support-classification
BloodAxe Jul 12, 2023
6693126
Copy implementation of jupyter_ui_poll to DG
BloodAxe Jul 12, 2023
b228a67
Merge remote-tracking branch 'origin/feature/DG-979-support-classific…
BloodAxe Jul 12, 2023
2cf8765
Added end2end test
BloodAxe Jul 12, 2023
c3eac5c
Added end2end test
BloodAxe Jul 12, 2023
717b214
Merge branch 'master' into feature/DG-979-support-classification
BloodAxe Jul 13, 2023
1e504c8
Merge branch 'master' into feature/DG-979-support-classification
BloodAxe Jul 14, 2023
949b018
Merge branch 'master' into feature/DG-979-support-classification
BloodAxe Jul 17, 2023
bccba58
Update master
BloodAxe Jul 17, 2023
015ee42
Added warning
BloodAxe Jul 18, 2023
247 changes: 247 additions & 0 deletions examples/classification_torchvision_caltech101.ipynb

Large diffs are not rendered by default.

235 changes: 235 additions & 0 deletions examples/classification_torchvision_fashion_mnist.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -17,3 +17,4 @@ seaborn
xhtml2pdf
jinja2
imagededup
jupyter-ui-poll==0.2.2
37 changes: 37 additions & 0 deletions src/data_gradients/assets/html/basic_info_fe_classification.html
@@ -0,0 +1,37 @@
<table align="center" border="0" cellpadding="1" cellspacing="1" style="width:800px">
<thead>
<tr>
<th scope="col" style="column-width: 300px;">
<h2>&nbsp;</h2>
</th>
<th scope="col" class="train_header">
<strong>Train</strong>
</th>
<th scope="col" class="val_header">
<strong>Validation</strong>
</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:left; color:black;">Images</td>
<td class="train_header"><strong>{{train.num_samples}}</strong></td>
<td class="val_header"><strong>{{val.num_samples}}</strong></td>
</tr>
<tr>
<td style="text-align:left; color:black;">Classes</td>
<td class="train_header"><strong>{{train.classes_count}}</strong></td>
<td class="val_header"><strong>{{val.classes_count}}</strong></td>
</tr>
<tr>
<td style="text-align:left; color:black;">Classes in use</td>
<td class="train_header"><strong>{{train.classes_in_use}}</strong></td>
<td class="val_header"><strong>{{val.classes_in_use}}</strong></td>
</tr>
<tr>
<td style="text-align:left; color:black;">Median image resolution</td>
<td class="train_text"><strong>{{train.med_image_resolution}}</strong></td>
<td class="val_text"><strong>{{val.med_image_resolution}}</strong></td>
</tr>
</tbody>
</table>
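The template above fills four values per split: image count, class count, classes in use, and median image resolution. A minimal sketch of how such values could be derived is shown below. It assumes a hypothetical list of `(class_id, (height, width))` pairs and a hypothetical `summarize_split` helper; the key names mirror the template placeholders, but this is not the PR's `ClassificationSummaryStats` implementation.

```python
from collections import Counter
from statistics import median
from typing import Dict, List, Tuple


def summarize_split(samples: List[Tuple[int, Tuple[int, int]]], class_names: List[str]) -> Dict[str, object]:
    """Compute the four values shown per split in the summary table.

    `samples` is a hypothetical list of (class_id, (height, width)) pairs.
    """
    class_counts = Counter(class_id for class_id, _ in samples)
    heights = [height for _, (height, _) in samples]
    widths = [width for _, (_, width) in samples]
    return {
        "num_samples": len(samples),
        "classes_count": len(class_names),  # all declared classes
        "classes_in_use": len(class_counts),  # classes that actually occur
        "med_image_resolution": f"{int(median(heights))}x{int(median(widths))}",
    }


train = summarize_split(
    [(0, (28, 28)), (2, (28, 28)), (2, (32, 30))],
    class_names=["shirt", "bag", "boot"],
)
print(train)
```

Note that `statistics.median` raises `StatisticsError` on an empty list, so a real implementation would need to handle an empty split.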
25 changes: 25 additions & 0 deletions src/data_gradients/batch_processors/classification.py
@@ -0,0 +1,25 @@
from typing import List

from data_gradients.batch_processors.adapters.dataset_adapter import DatasetAdapter
from data_gradients.batch_processors.base import BatchProcessor
from data_gradients.batch_processors.formatters.classification import ClassificationBatchFormatter
from data_gradients.batch_processors.preprocessors.classification import ClassificationBatchPreprocessor
from data_gradients.config.data.data_config import ClassificationDataConfig


class ClassificationBatchProcessor(BatchProcessor):
    def __init__(
        self,
        *,
        data_config: ClassificationDataConfig,
        class_names: List[str],
        class_names_to_use: List[str],
        n_image_channels: int = 3,
    ):
        dataset_adapter = DatasetAdapter(data_config=data_config)
        formatter = ClassificationBatchFormatter(
            data_config=data_config, class_names=class_names, class_names_to_use=class_names_to_use, n_image_channels=n_image_channels
        )
        preprocessor = ClassificationBatchPreprocessor(class_names=class_names, n_image_channels=n_image_channels)

        super().__init__(dataset_adapter=dataset_adapter, batch_formatter=formatter, batch_preprocessor=preprocessor)
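The processor above wires three stages together: an adapter that extracts images and labels, a formatter that validates them, and a preprocessor that yields per-sample objects. The sketch below illustrates that chaining with plain callables. `SketchBatchProcessor`, its `process` method, and the three stand-in functions are hypothetical names for illustration, not the `BatchProcessor` base class API.

```python
from typing import Any, Callable, Dict, Iterator


class SketchBatchProcessor:
    """Chains adapter -> formatter -> preprocessor, one stage per responsibility."""

    def __init__(self, adapter: Callable, formatter: Callable, preprocessor: Callable):
        self.adapter = adapter
        self.formatter = formatter
        self.preprocessor = preprocessor

    def process(self, raw_batch: Any, split: str) -> Iterator[Dict[str, Any]]:
        images, labels = self.adapter(raw_batch)  # pull tensors out of the raw batch
        images, labels = self.formatter(images, labels)  # validate and normalize
        yield from self.preprocessor(images, labels, split)  # one sample at a time


def adapter(raw_batch):
    return raw_batch["images"], raw_batch["labels"]


def formatter(images, labels):
    if len(images) != len(labels):
        raise ValueError("Labels and images should have the same length")
    return images, labels


def preprocessor(images, labels, split):
    for image, label in zip(images, labels):
        yield {"image": image, "class_id": int(label), "split": split}


processor = SketchBatchProcessor(adapter, formatter, preprocessor)
samples = list(processor.process({"images": ["img0", "img1"], "labels": [3, 1]}, split="train"))
print(samples)
```

Keeping the stages separate means a new task type only needs its own formatter and preprocessor, which appears to be the approach this PR takes for classification.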
72 changes: 72 additions & 0 deletions src/data_gradients/batch_processors/formatters/classification.py
@@ -0,0 +1,72 @@
from typing import Tuple, Optional, Callable, List

import torch
from torch import Tensor

from data_gradients.batch_processors.utils import check_all_integers
from data_gradients.batch_processors.formatters.base import BatchFormatter
from data_gradients.batch_processors.formatters.utils import ensure_images_shape, ensure_channel_first, drop_nan
from data_gradients.config.data.data_config import DetectionDataConfig, ClassificationDataConfig
from data_gradients.batch_processors.formatters.utils import DatasetFormatError


class UnsupportedClassificationBatchFormatError(DatasetFormatError):
    def __init__(self, message: str):
        super().__init__(message)


class ClassificationBatchFormatter(BatchFormatter):
    """Classification formatter class"""

    def __init__(
        self,
        data_config: ClassificationDataConfig,
        class_names: List[str],
        class_names_to_use: List[str],
        n_image_channels: int,
    ):
        """
        :param class_names: List of all class names in the dataset. The index should represent the class_id.
        :param class_names_to_use: List of class names that we should use for analysis.
        :param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
        """
        self.data_config = data_config

        class_names_to_use = set(class_names_to_use)
        self.class_ids_to_use = [class_id for class_id, class_name in enumerate(class_names) if class_name in class_names_to_use]

        self.n_image_channels = n_image_channels

    def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
        """Validate batch images and labels format, and ensure that they are in the relevant format for classification.

        :param images: Batch of images, in (BS, ...) format
        :param labels: Batch of labels, in (BS) format
        :return:
            - images: Batch of images already formatted into (BS, C, H, W)
            - labels: Batch of targets (BS)
        """

        images = ensure_channel_first(images, n_image_channels=self.n_image_channels)
        images = ensure_images_shape(images, n_image_channels=self.n_image_channels)
        labels = self.ensure_labels_shape(images=images, labels=labels)

        # Images in the [0, 1] range are rescaled to [0, 255] and cast to uint8.
        if 0 <= images.min() and images.max() <= 1:
            images *= 255
            images = images.to(torch.uint8)

        return images, labels

    @staticmethod
    def ensure_labels_shape(labels: Tensor, images: Tensor) -> Tensor:
        """Make sure that the labels have the correct shape, i.e. (BS)."""
        if torch.is_floating_point(labels):
            raise UnsupportedClassificationBatchFormatError("Labels should be integers")

        if labels.ndim != 1:
            raise UnsupportedClassificationBatchFormatError("Labels should be 1D tensor")

        if len(labels) != len(images):
            raise UnsupportedClassificationBatchFormatError("Labels and images should have the same length")

        return labels
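The formatter above does three things: it maps the class names to analyze onto class ids, it rejects labels of the wrong type or length, and it rescales images in the [0, 1] range to [0, 255]. The pure-Python sketch below mirrors those checks on plain lists instead of tensors. The function names `select_class_ids`, `check_labels`, and `to_uint8_range` are hypothetical and only illustrate the logic.

```python
from typing import List, Sequence


def select_class_ids(class_names: Sequence[str], class_names_to_use: Sequence[str]) -> List[int]:
    """Return the ids (list positions) of the classes that should be analyzed."""
    wanted = set(class_names_to_use)
    return [class_id for class_id, name in enumerate(class_names) if name in wanted]


def check_labels(labels: Sequence[object], n_images: int) -> None:
    """Raise if labels are not integers or do not match the number of images."""
    if any(isinstance(label, bool) or not isinstance(label, int) for label in labels):
        raise ValueError("Labels should be integers")
    if len(labels) != n_images:
        raise ValueError("Labels and images should have the same length")


def to_uint8_range(pixels: List[float]) -> List[int]:
    """Rescale values in [0, 1] to [0, 255]; truncate everything else to int."""
    if pixels and 0 <= min(pixels) and max(pixels) <= 1:
        return [int(p * 255) for p in pixels]
    return [int(p) for p in pixels]


print(select_class_ids(["cat", "dog", "bird"], ["bird", "cat"]))
print(to_uint8_range([0.0, 0.5, 1.0]))
check_labels([0, 2, 1], n_images=3)
```

One edge case worth keeping in mind with a range-based heuristic like this: an image that is already stored in the 0 to 255 range but happens to contain only the values 0 and 1 would also be rescaled.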
46 changes: 46 additions & 0 deletions src/data_gradients/batch_processors/preprocessors/classification.py
@@ -0,0 +1,46 @@
from typing import Iterable, List
from torch import Tensor
import numpy as np
import time

from data_gradients.utils.data_classes import DetectionSample
from data_gradients.batch_processors.preprocessors.base import BatchPreprocessor
from data_gradients.utils.data_classes.data_samples import ImageChannelFormat, ClassificationSample


class ClassificationBatchPreprocessor(BatchPreprocessor):
    def __init__(self, class_names: List[str], n_image_channels: int):
        """
        :param class_names: List of all class names in the dataset. The index should represent the class_id.
        :param n_image_channels: Number of image channels (1 for grayscale, 3 for RGB).
        """
        if n_image_channels not in [1, 3]:
            raise ValueError(f"n_image_channels should be either 1 or 3, but got {n_image_channels}")
        self.class_names = class_names
        self.n_image_channels = n_image_channels

    def preprocess(self, images: Tensor, labels: Tensor, split: str) -> Iterable[ClassificationSample]:
        """Group batch images and labels into a single ready-to-analyze batch object, including all relevant preprocessing.

        :param images: Batch of images already formatted into (BS, C, H, W)
        :param labels: Batch of targets (BS)
        :param split: Name of the split (train, val, test)
        :return: Iterable of ready-to-analyze classification samples.
        """
        # Convert (BS, C, H, W) to channel-last (BS, H, W, C) uint8 arrays.
        images = np.uint8(np.transpose(images.cpu().numpy(), (0, 2, 3, 1)))

        # TODO: image_format is hard-coded here, but it should be refactored afterwards
        image_format = {1: ImageChannelFormat.GRAYSCALE, 3: ImageChannelFormat.RGB}[self.n_image_channels]

        for image, target in zip(images, labels):
            class_id = int(target)

            sample = ClassificationSample(
                image=image,
                class_id=class_id,
                class_names=self.class_names,
                split=split,
                image_format=image_format,
                sample_id=None,
            )
            sample.sample_id = str(id(sample))
            yield sample
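The preprocessor above turns a batch into a stream of per-image samples and tags each one with an id derived from `id(sample)`. The sketch below reproduces that pattern with a stand-in dataclass. `Sample`, `ChannelFormat`, and `iter_samples` are hypothetical names and do not reflect the real `ClassificationSample` definition.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Iterator, List, Optional, Sequence


class ChannelFormat(Enum):
    GRAYSCALE = "grayscale"
    RGB = "rgb"


@dataclass
class Sample:
    image: object
    class_id: int
    class_names: List[str]
    split: str
    image_format: ChannelFormat
    sample_id: Optional[str] = None


def iter_samples(
    images: Sequence[object], labels: Sequence[int], class_names: List[str], split: str, n_image_channels: int
) -> Iterator[Sample]:
    """Yield one Sample per (image, label) pair, lazily."""
    # Only 1-channel and 3-channel images are supported; anything else raises KeyError.
    image_format = {1: ChannelFormat.GRAYSCALE, 3: ChannelFormat.RGB}[n_image_channels]
    for image, label in zip(images, labels):
        sample = Sample(image=image, class_id=int(label), class_names=class_names, split=split, image_format=image_format)
        sample.sample_id = str(id(sample))  # id() is only unique among objects alive at the same time
        yield sample


samples = list(iter_samples(["img0", "img1"], [1, 0], ["cat", "dog"], split="val", n_image_channels=3))
print([(s.class_id, s.image_format.name) for s in samples])
```

Because `id()` values can be reused once an object is garbage collected, ids generated this way are not guaranteed to be unique across a whole dataset when samples are consumed and discarded one by one.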
12 changes: 12 additions & 0 deletions src/data_gradients/config/classification.yaml
@@ -0,0 +1,12 @@
report_sections:
  - name: Image Features
    features:
      - ClassificationSummaryStats
      - ImagesResolution
      - ImageColorDistribution
      - ImagesAverageBrightness
  - name: Classification Features
    features:
      - ClassificationClassFrequency
      - ClassificationClassDistributionVsArea
      # - ClassificationClassDistributionVsAreaPlot
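The config above groups feature extractor names into report sections. As a rough illustration of how such a structure could be consumed, the sketch below writes the same content as a Python dict (the form a YAML loader would produce) and checks every listed name against a set of known extractors. `list_feature_names`, `find_unknown_features`, and the `known` set are hypothetical helpers, not part of the library.

```python
from typing import Dict, List, Set

# The YAML config above, written as the nested structure a YAML loader would return.
config: Dict[str, List[dict]] = {
    "report_sections": [
        {
            "name": "Image Features",
            "features": ["ClassificationSummaryStats", "ImagesResolution", "ImageColorDistribution", "ImagesAverageBrightness"],
        },
        {
            "name": "Classification Features",
            "features": ["ClassificationClassFrequency", "ClassificationClassDistributionVsArea"],
        },
    ]
}


def list_feature_names(cfg: Dict[str, List[dict]]) -> List[str]:
    """Flatten all sections into one ordered list of feature names."""
    return [feature for section in cfg["report_sections"] for feature in section["features"]]


def find_unknown_features(cfg: Dict[str, List[dict]], known: Set[str]) -> List[str]:
    """Return configured names that have no matching extractor."""
    return [name for name in list_feature_names(cfg) if name not in known]


known = set(list_feature_names(config)) - {"ImagesAverageBrightness"}
print(list_feature_names(config))
print(find_unknown_features(config, known))
```

A check like this catches a misspelled feature name when the config is loaded rather than partway through an analysis run.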
5 changes: 5 additions & 0 deletions src/data_gradients/config/data/data_config.py
@@ -128,6 +128,11 @@ def get_labels_extractor(self, question: Optional[Question] = None, hint: str =
        return TensorExtractorResolver.to_callable(tensor_extractor=self.labels_extractor)


@dataclass
class ClassificationDataConfig(DataConfig):
    pass


@dataclass
class SegmentationDataConfig(DataConfig):
    pass
79 changes: 78 additions & 1 deletion src/data_gradients/config/data/questions.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from time import sleep
from typing import Dict, Any, Optional, List


@@ -34,18 +35,37 @@ def ask_question(question: Optional[Question], hint: str = "") -> Any:
return question.options[answer]


def ask_user(main_question: str, options: List[str], optional_description: str = "") -> str:
def is_notebook() -> bool:
    try:
        from IPython import get_ipython

        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except ImportError:
        return False
    except NameError:
        return False  # Probably standard Python interpreter


def ask_user_via_stdin(main_question: str, options: List[str], optional_description: str = "") -> str:
    """Prompt the user to choose an option from a list of options.
    :param main_question: The main question or instruction for the user.
    :param options: List of options to choose from.
    :param optional_description: Optional description to display to the user.
    :return: The chosen option (one of the entries of `options`).
    """

    numbers_to_chose_from = range(len(options))

    options_formatted = "\n".join([f"[{text_to_blue(number)}] | {option_description}" for number, option_description in zip(numbers_to_chose_from, options)])

    user_answer = None

    while user_answer not in numbers_to_chose_from:
        print("\n------------------------------------------------------------------------")
        print(f"{main_question}")
@@ -68,4 +88,61 @@ def ask_user(main_question: str, options: List[str], optional_description: str =
selected_option = options[user_answer]
print(f"Great! You chose: {text_to_yellow(selected_option)}\n")


def ask_user_via_jupyter(main_question: str, options: List[str], optional_description: str = "") -> str:
    numbers_to_chose_from = range(len(options))

    options_formatted = "\n".join([f"[{text_to_blue(number)}] | {option_description}" for number, option_description in zip(numbers_to_chose_from, options)])

    user_answer = None

    print("\n------------------------------------------------------------------------")
    print(f"{main_question}")
    print("------------------------------------------------------------------------")
    if optional_description:
        print(optional_description)
    print("\nOptions:")
    print(options_formatted)
    print("")

    import ipywidgets as widgets
    from IPython.display import display
    from jupyter_ui_poll import ui_events

    for i, option in enumerate(options):
        button = widgets.Button(description=option)
        button.value = i
        output = widgets.Output()

        display(button, output)

        def on_button_clicked(b):
            nonlocal user_answer
            with output:
                user_answer = b.value
                # b.value is an int, so format it instead of concatenating it to a str
                print(f"You selected option: {b.value}")

        button.on_click(on_button_clicked)

    # Process pending UI events until one of the buttons sets user_answer
    with ui_events() as poll:
        while user_answer is None:
            poll(10)

    selected_option = options[user_answer]
    print(f"Great! You chose: {text_to_yellow(selected_option)}\n")
    return selected_option


def ask_user(main_question: str, options: List[str], optional_description: str = "") -> str:
    """Prompt the user to choose an option from a list of options.
    Depending on the environment, the user will be prompted via stdin or via a Jupyter widget.
    :param main_question: The main question or instruction for the user.
    :param options: List of options to choose from.
    :param optional_description: Optional description to display to the user.
    :return: The chosen option (one of the entries of `options`).
    """

    if is_notebook():
        return ask_user_via_jupyter(main_question, options, optional_description)
    else:
        return ask_user_via_stdin(main_question, options, optional_description)
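Both prompt variants above share one idea: keep asking until the answer is a valid index into `options`, then return the matching option. The sketch below isolates that loop and takes the answer source as a parameter so it can be driven by `input` or by a scripted sequence. `choose_option` and `read_answer` are hypothetical names and not part of `questions.py`.

```python
from typing import Callable, List


def choose_option(options: List[str], read_answer: Callable[[], str]) -> str:
    """Keep asking until `read_answer` returns a valid option index, then return that option."""
    while True:
        raw = read_answer()
        try:
            index = int(raw)
        except ValueError:
            continue  # not a number, ask again
        if 0 <= index < len(options):
            return options[index]
        # out-of-range number, ask again


# Scripted answers stand in for a user typing "x", then "7", then "1".
answers = iter(["x", "7", "1"])
choice = choose_option(["Train", "Validation"], read_answer=lambda: next(answers))
print(choice)
```

In a terminal the same function could be called with `read_answer=input`. Injecting the answer source also makes the loop easy to exercise in an end-to-end test without a live stdin or notebook kernel.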
10 changes: 10 additions & 0 deletions src/data_gradients/feature_extractors/__init__.py
@@ -21,6 +21,12 @@
    DetectionSampleVisualization,
    DetectionBoundingBoxIoU,
)
from .classification import (
    ClassificationClassFrequency,
    ClassificationSummaryStats,
    ClassificationClassDistributionVsArea,
    ClassificationClassDistributionVsAreaPlot,
)

__all__ = [
"ImageDuplicates",
@@ -46,4 +52,8 @@
    "DetectionClassesPerImageCount",
    "DetectionSampleVisualization",
    "DetectionBoundingBoxIoU",
    "ClassificationClassFrequency",
    "ClassificationSummaryStats",
    "ClassificationClassDistributionVsArea",
    "ClassificationClassDistributionVsAreaPlot",
]
11 changes: 11 additions & 0 deletions src/data_gradients/feature_extractors/classification/__init__.py
@@ -0,0 +1,11 @@
from .class_frequency import ClassificationClassFrequency
from .summary import ClassificationSummaryStats
from .class_distribution_vs_area import ClassificationClassDistributionVsArea
from .class_distribution_vs_area_scatter import ClassificationClassDistributionVsAreaPlot

__all__ = [
    "ClassificationClassFrequency",
    "ClassificationSummaryStats",
    "ClassificationClassDistributionVsArea",
    "ClassificationClassDistributionVsAreaPlot",
]