Skip to content

Commit

Permalink
Feature/dg 979 support classification (#149)
Browse files Browse the repository at this point in the history
* Classification support (WIP)

* Added logging of the event if feature extractor failed

* Fixing summary report for classification

* Remove default value batches_early_stop for ClassificationAnalysisManager

* Remove default value batches_early_stop for ClassificationAnalysisManager

* Remove default value batches_early_stop for ClassificationAnalysisManager

* Support dataset

* Copy-paste bugfix

* New feature extractor ClassificationClassDistributionVsArea

* Change x axis to use image size instead of image area

* Added action points to description

* Added action points to description

* Added action points to description

* Fix PR

* Added normalization to handle case when images were normalized with some unknown mean/std

* Copy implementation of jupyter_ui_poll to DG

* Added end2end test

* Added end2end test

* Update master

* Added warning

---------

Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com>
  • Loading branch information
BloodAxe and shaydeci authored Jul 19, 2023
1 parent 3fa26ca commit 08f406c
Show file tree
Hide file tree
Showing 28 changed files with 1,967 additions and 1 deletion.
242 changes: 242 additions & 0 deletions examples/classification_torchvision_caltech101.ipynb

Large diffs are not rendered by default.

232 changes: 232 additions & 0 deletions examples/classification_torchvision_fashion_mnist.ipynb

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions src/data_gradients/assets/html/basic_info_fe_classification.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<!--
Summary table for a classification dataset with one column per split (Train / Validation).
Rows: number of images, number of classes, number of classes in use, median image resolution.
The double-brace placeholders are presumably substituted by the report renderer; verify against the caller.
NOTE(review): the last row uses the train_text / val_text classes while every other row uses
train_header / val_header. Confirm this styling difference is intentional.
-->
<table align="center" border="0" cellpadding="1" cellspacing="1" style="width:800px">
<thead>
<tr>
<th scope="col" style="column-width: 300px;">
<h2>&nbsp;</h2>
</th>
<th scope="col" class="train_header">
<strong>Train</strong>
</th>
<th scope="col" class="val_header">
<strong>Validation</strong>
</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:left; color:black;">Images</td>
<td class="train_header"><strong>{{train.num_samples}}</strong></td>
<td class="val_header"><strong>{{val.num_samples}}</strong></td>
</tr>
<tr>
<td style="text-align:left; color:black;">Classes</td>
<td class="train_header"><strong>{{train.classes_count}}</strong></td>
<td class="val_header"><strong>{{val.classes_count}}</strong></td>
</tr>
<tr>
<td style="text-align:left; color:black;">Classes in use</td>
<td class="train_header"><strong>{{train.classes_in_use}}</strong></td>
<td class="val_header"><strong>{{val.classes_in_use}}</strong></td>
</tr>
<tr>
<td style="text-align:left; color:black;">Median image resolution</td>
<td class="train_text"><strong>{{train.med_image_resolution}}</strong></td>
<td class="val_text"><strong>{{val.med_image_resolution}}</strong></td>
</tr>
</tbody>
</table>
25 changes: 25 additions & 0 deletions src/data_gradients/batch_processors/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List

from data_gradients.batch_processors.base import BatchProcessor
from data_gradients.batch_processors.formatters.classification import ClassificationBatchFormatter
from data_gradients.batch_processors.output_mapper.dataset_output_mapper import DatasetOutputMapper
from data_gradients.batch_processors.preprocessors.classification import ClassificationBatchPreprocessor
from data_gradients.config.data.data_config import ClassificationDataConfig


class ClassificationBatchProcessor(BatchProcessor):
    """Batch processor for classification: wires together the output mapper, the formatter and the preprocessor."""

    def __init__(
        self,
        *,
        data_config: ClassificationDataConfig,
        class_names: List[str],
        class_names_to_use: List[str],
        n_image_channels: int = 3,
    ):
        """
        :param data_config:         Data configuration, forwarded to the output mapper and to the formatter.
        :param class_names:         List of all class names in the dataset, forwarded to the formatter and the preprocessor.
        :param class_names_to_use:  List of class names to analyse, forwarded to the formatter.
        :param n_image_channels:    Number of image channels, forwarded to the formatter and the preprocessor.
        """
        # Keyword arguments are evaluated left to right, so the components are built in the order mapper, formatter, preprocessor.
        super().__init__(
            dataset_output_mapper=DatasetOutputMapper(data_config=data_config),
            batch_formatter=ClassificationBatchFormatter(
                data_config=data_config,
                class_names=class_names,
                class_names_to_use=class_names_to_use,
                n_image_channels=n_image_channels,
            ),
            batch_preprocessor=ClassificationBatchPreprocessor(
                class_names=class_names,
                n_image_channels=n_image_channels,
            ),
        )
83 changes: 83 additions & 0 deletions src/data_gradients/batch_processors/formatters/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import warnings
from typing import Tuple, List

import torch
from torch import Tensor

from data_gradients.batch_processors.formatters.base import BatchFormatter
from data_gradients.batch_processors.formatters.utils import DatasetFormatError, check_images_shape
from data_gradients.batch_processors.formatters.utils import ensure_channel_first
from data_gradients.config.data.data_config import ClassificationDataConfig


class UnsupportedClassificationBatchFormatError(DatasetFormatError):
    """Error raised when a classification batch is not in a supported format."""

    # NOTE(review): the parameter name `str` shadows the builtin inside this method.
    # It is left untouched here because renaming it would change the keyword interface.
    def __init__(self, str):
        super().__init__(str)


class ClassificationBatchFormatter(BatchFormatter):
    """Validate classification batches and bring their images into a uint8 [0, 255] range when needed."""

    def __init__(
        self,
        data_config: ClassificationDataConfig,
        class_names: List[str],
        class_names_to_use: List[str],
        n_image_channels: int,
    ):
        """
        :param data_config:         Data configuration; stored on the instance as-is.
        :param class_names:         List of all class names in the dataset. The index should represent the class_id.
        :param class_names_to_use:  List of class names that we should use for analysis.
        :param n_image_channels:    Number of image channels (3 for RGB, 1 for Gray Scale, ...)
        """
        self.data_config = data_config

        # A set gives O(1) membership tests while filtering the class ids below.
        class_names_to_use = set(class_names_to_use)
        self.class_ids_to_use = [class_id for class_id, class_name in enumerate(class_names) if class_name in class_names_to_use]

        self.n_image_channels = n_image_channels

    def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
        """Validate batch images and labels format, and ensure that they are in the relevant format for classification.

        The value rescaling is done out of place, so the tensor passed by the caller is never modified by it.

        :param images: Batch of images, in (BS, ...) format
        :param labels: Batch of labels, in (BS) format
        :return:
            - images: Batch of images already formatted into (BS, C, H, W)
            - labels: Batch of targets (BS)
        :raises UnsupportedClassificationBatchFormatError: If the labels are not a 1D integer tensor with one entry per image.
        """
        images = ensure_channel_first(images, n_image_channels=self.n_image_channels)
        images = check_images_shape(images, n_image_channels=self.n_image_channels)
        labels = self.ensure_labels_shape(images=images, labels=labels)

        images = self._rescale_to_uint8(images)

        return images, labels

    @staticmethod
    def _rescale_to_uint8(images: Tensor) -> Tensor:
        """Map images to uint8 [0, 255] when their values are in [0, 1] or contain negative numbers.

        :param images: Batch of images of any numeric dtype.
        :return: A new uint8 tensor if rescaling was needed, otherwise `images` unchanged.
        """
        # Compute the extrema once; they drive every decision below.
        lowest, highest = images.min(), images.max()

        if 0 <= lowest and highest <= 1:
            return (images * 255).to(torch.uint8)

        if lowest < 0:  # images were normalized with some unknown mean and std
            warnings.warn(
                "Images were normalized with some unknown mean and std. "
                "For visualization needs and color distribution plots Data Gradients will try to scale them to [0, 255] range. "
                "This normalization will use min-max scaling per batch which may make the images look brighter/darker than they should be. "
            )
            # Division needs a floating dtype; integer inputs are promoted instead of failing.
            as_float = images if torch.is_floating_point(images) else images.to(torch.float32)
            shifted = as_float - lowest
            peak = shifted.max()
            # A constant batch has peak == 0; skip the division so the result is all zeros instead of NaN.
            if peak > 0:
                shifted = shifted / peak
            return (shifted * 255).to(torch.uint8)

        # Non-negative values above 1 are left as they are.
        return images

    @staticmethod
    def ensure_labels_shape(labels: Tensor, images: Tensor) -> Tensor:
        """Make sure that the labels have the correct shape, i.e. (BS).

        :param labels: Batch of labels to validate.
        :param images: Batch of images; only its length is used.
        :return: The labels, unchanged.
        :raises UnsupportedClassificationBatchFormatError: If labels are floating point, not 1D, or differ in length from images.
        """
        if torch.is_floating_point(labels):
            raise UnsupportedClassificationBatchFormatError("Labels should be integers")

        if labels.ndim != 1:
            raise UnsupportedClassificationBatchFormatError("Labels should be 1D tensor")

        if len(labels) != len(images):
            raise UnsupportedClassificationBatchFormatError("Labels and images should have the same length")

        return labels
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Iterable, List
from torch import Tensor
import numpy as np
import time

from data_gradients.utils.data_classes import DetectionSample
from data_gradients.batch_processors.preprocessors.base import BatchPreprocessor
from data_gradients.utils.data_classes.data_samples import ImageChannelFormat, ClassificationSample


class ClassificationBatchPreprocessor(BatchPreprocessor):
    """Turn formatted classification batches into one ClassificationSample per image."""

    def __init__(self, class_names: List[str], n_image_channels: int):
        """
        :param class_names:         List of all class names in the dataset. The index should represent the class_id.
        :param n_image_channels:    Number of image channels; only 1 (grayscale) and 3 (RGB) are accepted.
        :raises ValueError:         If n_image_channels is neither 1 nor 3.
        """
        if n_image_channels not in [1, 3]:
            raise ValueError(f"n_image_channels should be either 1 or 3, but got {n_image_channels}")
        self.class_names = class_names
        self.n_image_channels = n_image_channels

    def preprocess(self, images: Tensor, labels: Tensor, split: str) -> Iterable[ClassificationSample]:
        """Group batch images and labels into ready-to-analyze sample objects, including all relevant preprocessing.

        :param images:  Batch of images already formatted into (BS, C, H, W)
        :param labels:  Batch of targets (BS)
        :param split:   Name of the split (train, val, test)
        :return:        Iterable of ready to analyse classification samples.
        """
        # Move to channel-last (BS, H, W, C) numpy arrays, which is what each sample stores.
        images = np.uint8(np.transpose(images.cpu().numpy(), (0, 2, 3, 1)))

        # TODO: image_format is hard-coded here, but it should be refactored afterwards
        image_format = {1: ImageChannelFormat.GRAYSCALE, 3: ImageChannelFormat.RGB}[self.n_image_channels]

        for image, target in zip(images, labels):
            class_id = int(target)

            sample = ClassificationSample(
                image=image,
                class_id=class_id,
                class_names=self.class_names,
                split=split,
                image_format=image_format,
                sample_id=None,
            )
            # The id can only be derived once the object exists.
            # NOTE(review): id() is unique only among objects alive at the same time, so ids may repeat across samples.
            sample.sample_id = str(id(sample))
            yield sample
12 changes: 12 additions & 0 deletions src/data_gradients/config/classification.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
report_sections:
- name: Image Features
features:
- ClassificationSummaryStats
- ImagesResolution
- ImageColorDistribution
- ImagesAverageBrightness
- name: Classification Features
features:
- ClassificationClassFrequency
- ClassificationClassDistributionVsArea
# - ClassificationClassDistributionVsAreaPlot
5 changes: 5 additions & 0 deletions src/data_gradients/config/data/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def get_labels_extractor(self, question: Optional[Question] = None, hint: str =
return TensorExtractorResolver.to_callable(tensor_extractor=self.labels_extractor)


@dataclass
class ClassificationDataConfig(DataConfig):
    """Data configuration used for classification; declares nothing beyond what DataConfig provides."""

    pass


@dataclass
class SegmentationDataConfig(DataConfig):
pass
Expand Down
79 changes: 78 additions & 1 deletion src/data_gradients/config/data/questions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from time import sleep
from typing import Dict, Any, Optional, List

from data_gradients.utils.utils import text_to_blue, text_to_yellow
Expand Down Expand Up @@ -28,18 +29,37 @@ def ask_question(question: Optional[Question], hint: str = "") -> Any:
return question.options[answer]


def ask_user(main_question: str, options: List[str], optional_description: str = "") -> str:
def is_notebook() -> bool:
    """Tell whether the code is running inside a Jupyter notebook or qtconsole.

    :return: True only when the active IPython shell is a ZMQInteractiveShell; False for a terminal IPython,
             any other shell, a plain Python interpreter, or when IPython is not installed.
    """
    try:
        from IPython import get_ipython

        shell_name = get_ipython().__class__.__name__
    except (ImportError, NameError):
        # IPython is missing, or we are probably in a standard Python interpreter.
        return False

    # "TerminalInteractiveShell" and every other shell type count as "not a notebook".
    return shell_name == "ZMQInteractiveShell"


def ask_user_via_stdin(main_question: str, options: List[str], optional_description: str = "") -> str:
"""Prompt the user to choose an option from a list of options.
:param main_question: The main question or instruction for the user.
:param options: List of options to chose from.
:param optional_description: Optional description to display to the user.
:return: The chosen option (key from the options_described dictionary).
"""

numbers_to_chose_from = range(len(options))

options_formatted = "\n".join([f"[{text_to_blue(number)}] | {option_description}" for number, option_description in zip(numbers_to_chose_from, options)])

user_answer = None

while user_answer not in numbers_to_chose_from:
print("\n------------------------------------------------------------------------")
print(f"{main_question}")
Expand All @@ -62,4 +82,61 @@ def ask_user(main_question: str, options: List[str], optional_description: str =
selected_option = options[user_answer]
print(f"Great! You chose: {text_to_yellow(selected_option)}\n")


def ask_user_via_jupyter(main_question: str, options: List[str], optional_description: str = "") -> str:
    """Prompt the user to choose an option by clicking one of the buttons displayed in a Jupyter notebook.

    One button is displayed per option; the function blocks, while still processing UI events, until a button is clicked.

    :param main_question:           The main question or instruction for the user.
    :param options:                 List of options to choose from.
    :param optional_description:    Optional description to display to the user.
    :return:                        The chosen option (one of the entries of `options`).
    """
    options_formatted = "\n".join([f"[{text_to_blue(number)}] | {option_description}" for number, option_description in enumerate(options)])

    print("\n------------------------------------------------------------------------")
    print(f"{main_question}")
    print("------------------------------------------------------------------------")
    if optional_description:
        print(optional_description)
    print("\nOptions:")
    print(options_formatted)
    print("")

    # Imported lazily: these packages are only needed (and only expected to be available) inside a notebook.
    import ipywidgets as widgets
    from IPython.display import display
    from data_gradients.utils.jupyter_utils import ui_events

    user_answer = None

    def on_button_clicked(b, target_output):
        nonlocal user_answer
        user_answer = b.value
        with target_output:
            # b.value is an int, so it must be formatted rather than concatenated to a str.
            print(f"You selected option: {b.value}")

    for i, option in enumerate(options):
        button = widgets.Button(description=option)
        button.value = i
        output = widgets.Output()

        display(button, output)

        # Bind this button's own output as a default argument; a plain closure over the loop variable
        # would make every button write into the output of the last one.
        button.on_click(lambda b, target_output=output: on_button_clicked(b, target_output))

    with ui_events() as poll:
        while user_answer is None:
            poll(10)
            # Yield the CPU between polls instead of spinning at full speed while waiting for a click.
            sleep(0.1)

    selected_option = options[user_answer]
    print(f"Great! You chose: {text_to_yellow(selected_option)}\n")
    return selected_option


def ask_user(main_question: str, options: List[str], optional_description: str = "") -> str:
    """Ask the user to pick one of `options`, using the prompt that fits the current environment.

    In a Jupyter notebook the question is asked through widgets, otherwise through stdin.

    :param main_question:           The main question or instruction for the user.
    :param options:                 List of options to choose from.
    :param optional_description:    Optional description to display to the user.
    :return:                        Whatever the selected prompt function returns, i.e. the chosen option.
    """
    prompt = ask_user_via_jupyter if is_notebook() else ask_user_via_stdin
    return prompt(main_question, options, optional_description)
10 changes: 10 additions & 0 deletions src/data_gradients/feature_extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
DetectionSampleVisualization,
DetectionBoundingBoxIoU,
)
from .classification import (
ClassificationClassFrequency,
ClassificationSummaryStats,
ClassificationClassDistributionVsArea,
ClassificationClassDistributionVsAreaPlot,
)

__all__ = [
"ImageDuplicates",
Expand All @@ -46,4 +52,8 @@
"DetectionClassesPerImageCount",
"DetectionSampleVisualization",
"DetectionBoundingBoxIoU",
"ClassificationClassFrequency",
"ClassificationSummaryStats",
"ClassificationClassDistributionVsArea",
"ClassificationClassDistributionVsAreaPlot"
]
11 changes: 11 additions & 0 deletions src/data_gradients/feature_extractors/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .class_frequency import ClassificationClassFrequency
from .summary import ClassificationSummaryStats
from .class_distribution_vs_area import ClassificationClassDistributionVsArea
from .class_distribution_vs_area_scatter import ClassificationClassDistributionVsAreaPlot

__all__ = [
"ClassificationClassFrequency",
"ClassificationSummaryStats",
"ClassificationClassDistributionVsArea",
"ClassificationClassDistributionVsAreaPlot"
]
Loading

0 comments on commit 08f406c

Please sign in to comment.