🚀 Add PreProcessor to AnomalyModule #2358

Open · wants to merge 60 commits into base: feature/v2

Conversation

@samet-akcay (Contributor) commented Oct 9, 2024

📝 Description

The PreProcessor class serves as both a PyTorch module and a Lightning callback, handling transforms during different stages of training, validation, testing and prediction. This PR demonstrates how to create and use custom pre-processors.

Key Components

The pre-processor functionality is implemented in:

class PreProcessor(nn.Module, Callback):
    """Anomalib pre-processor.

    This class serves as both a PyTorch module and a Lightning callback, handling
    the application of transforms to data batches during different stages of
    training, validation, testing, and prediction.

    Args:
        train_transform (Transform | None): Transform to apply during training.
        val_transform (Transform | None): Transform to apply during validation.
        test_transform (Transform | None): Transform to apply during testing.
        transform (Transform | None): General transform to apply if stage-specific
            transforms are not provided.

    Raises:
        ValueError: If both `transform` and any of the stage-specific transforms
            are provided simultaneously.

    Notes:
        If only `transform` is provided, it will be used for all stages (train, val, test).

        Priority of transforms:
        1. Explicitly set PreProcessor transforms (highest priority)
        2. Datamodule transforms (if PreProcessor has no transforms)
        3. Dataloader transforms (if neither PreProcessor nor datamodule have transforms)
        4. Default transforms (lowest priority)

    Examples:
        >>> from lightning.pytorch import LightningModule
        >>> from torchvision.transforms.v2 import CenterCrop, Compose, Resize, ToTensor
        >>> from anomalib.pre_processing import PreProcessor

        >>> # Define transforms
        >>> train_transform = Compose([Resize((224, 224)), ToTensor()])
        >>> val_transform = Compose([Resize((256, 256)), CenterCrop((224, 224)), ToTensor()])

        >>> # Create PreProcessor with stage-specific transforms
        >>> pre_processor = PreProcessor(
        ...     train_transform=train_transform,
        ...     val_transform=val_transform
        ... )

        >>> # Create PreProcessor with a single transform for all stages
        >>> common_transform = Compose([Resize((224, 224)), ToTensor()])
        >>> pre_processor_common = PreProcessor(transform=common_transform)

        >>> # Use in a Lightning module
        >>> class MyModel(LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.pre_processor = PreProcessor(...)
        ...
        ...     def configure_callbacks(self):
        ...         return [self.pre_processor]
        ...
        ...     def training_step(self, batch, batch_idx):
        ...         # The pre_processor will automatically apply the correct transform
        ...         processed_batch = self.pre_processor(batch)
        ...         # Rest of the training step
    """

It is used by the base AnomalyModule in:

    def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor:
        """Resolve and validate which pre-processor to use..

        Args:
            pre_processor: Pre-processor configuration
                - True -> use default pre-processor
                - False -> no pre-processor
                - PreProcessor -> use the provided pre-processor

        Returns:
            Configured pre-processor
        """
        if isinstance(pre_processor, PreProcessor):
            return pre_processor
        if isinstance(pre_processor, bool):
            return self.configure_pre_processor()
        msg = f"Invalid pre-processor type: {type(pre_processor)}"
        raise TypeError(msg)

Usage Examples

1. Using Default Pre-Processor

The simplest way is to use the default pre-processor, which resizes images to 256x256 and normalizes them using ImageNet statistics:

from anomalib.models import PatchCore

# Uses default pre-processor
model = PatchCore()
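
For reference, an explicit pre-processor roughly equivalent to that default could look like the sketch below, built from the 256x256 resize and ImageNet normalization described above (the actual default transform may differ in detail):

from anomalib.models import PatchCore
from anomalib.pre_processing import PreProcessor
from torchvision.transforms.v2 import Compose, Normalize, Resize

# Approximation of the default behaviour: resize to 256x256 and normalize
# with ImageNet statistics.
default_like = PreProcessor(
    transform=Compose([
        Resize((256, 256), antialias=True),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
)
model = PatchCore(pre_processor=default_like)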

2. Custom Pre-Processor with Different Transforms

Create a pre-processor with custom transforms for different stages:

from torchvision.transforms.v2 import Compose, Resize, CenterCrop, RandomHorizontalFlip, Normalize
from anomalib.pre_processing import PreProcessor

# Define stage-specific transforms
train_transform = Compose([
    Resize((256, 256), antialias=True),
    RandomHorizontalFlip(p=0.5),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = Compose([
    Resize((256, 256), antialias=True), 
    CenterCrop((224, 224)),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create pre-processor with different transforms per stage
pre_processor = PreProcessor(
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=val_transform  # Use same transform as validation for testing
)

# Use custom pre-processor in model
model = PatchCore(pre_processor=pre_processor)

3. Disable Pre-Processing

To disable pre-processing entirely:

model = PatchCore(pre_processor=False)
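
With pre-processing disabled, the model no longer resizes or normalizes its inputs, so the transforms have to come from elsewhere (datamodule, dataloader, or your own code), following the priority order described above. A minimal sketch of applying a transform manually to a dummy batch; the shapes and values here are illustrative only:

import torch
from torchvision.transforms.v2 import Compose, Normalize, Resize

from anomalib.models import PatchCore

model = PatchCore(pre_processor=False)

# Apply the transform explicitly, since the model will not do it.
transform = Compose([
    Resize((256, 256), antialias=True),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
images = torch.rand(8, 3, 900, 900)  # dummy batch of images
images = transform(images)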

4. Override Default Pre-Processor in Custom Model

Custom models can override the default pre-processor configuration:

from torchvision.transforms.v2 import Compose, Normalize, Resize

from anomalib.models.components.base import AnomalyModule
from anomalib.pre_processing import PreProcessor

class CustomModel(AnomalyModule):
    @classmethod
    def configure_pre_processor(cls, image_size=(224, 224)) -> PreProcessor:
        transform = Compose([
            Resize(image_size, antialias=True),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        return PreProcessor(transform=transform)
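
The override then takes effect whenever the default pre-processor is requested, for example (assuming CustomModel implements the rest of the AnomalyModule interface and that pre_processor defaults to True):

model = CustomModel()  # configure_pre_processor() supplies the 224x224 transform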

Notes

  • Pre-processor transforms are applied in order of priority:
    • Explicitly set PreProcessor transforms (highest)
    • Datamodule transforms
    • Dataloader transforms
    • Default transforms (lowest)
  • The pre-processor automatically handles both image and mask transforms during training
  • Custom transforms should maintain compatibility with both image and segmentation mask inputs (see the sketch below)
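
A minimal sketch of such a compatible transform, using torchvision's tv_tensors so that geometric operations are applied consistently to the image and the mask (shapes are illustrative):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import Compose, RandomHorizontalFlip, Resize

# v2 transforms dispatch on tv_tensor types, so the same pipeline resizes and
# flips the image and its segmentation mask together.
transform = Compose([
    Resize((256, 256), antialias=True),
    RandomHorizontalFlip(p=0.5),
])

image = tv_tensors.Image(torch.rand(3, 900, 900))
mask = tv_tensors.Mask(torch.zeros(900, 900, dtype=torch.uint8))
image, mask = transform(image, mask)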

Testing

  • Added unit tests to verify:
    • Default pre-processor behavior
    • Custom transform application
    • Transform priority order
    • Mask transformation handling
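
As an illustration, a test for the mutual-exclusivity check described in the docstring could look roughly like this (pytest assumed; not necessarily one of the tests added in this PR):

import pytest
from torchvision.transforms.v2 import Resize

from anomalib.pre_processing import PreProcessor


def test_transform_and_stage_specific_transform_are_mutually_exclusive():
    """PreProcessor should reject `transform` combined with a stage-specific transform."""
    with pytest.raises(ValueError):
        PreProcessor(transform=Resize((256, 256)), train_transform=Resize((224, 224)))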

✨ Changes

Select what type of change your PR is:

  • 🐞 Bug fix (non-breaking change which fixes an issue)
  • 🔨 Refactor (non-breaking change which refactors the code base)
  • 🚀 New feature (non-breaking change which adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📚 Documentation update
  • 🔒 Security update

✅ Checklist

Before you submit your pull request, please make sure you have completed the following steps:

  • 📋 I have summarized my changes in the CHANGELOG and followed the guidelines for my type of change (skip for minor changes, documentation updates, and test enhancements).
  • 📚 I have made the necessary updates to the documentation (if applicable).
  • 🧪 I have written tests that support my changes and prove that my fix is effective or my feature works (if applicable).

For more information about code review checklists, see the Code Review Checklist.

@jpcbertoldo (Contributor):

A sub-feature request that would fit here: (optionally?) keep both the transformed and original image/mask in the batch.

So instead of

            image, gt_mask = self.XXX_transform(batch.image, batch.gt_mask)
            batch.update(image=image, gt_mask=gt_mask)

something like

            batch.update(image_original=batch.image, gt_mask_original=batch.gt_mask)
            image, gt_mask = self.XXX_transform(batch.image, batch.gt_mask)
            batch.update(image=image, gt_mask=gt_mask)

It's quite practical to have these when using the API (I've re-implemented this in my local copy 100 times, haha).

@samet-akcay (Contributor, Author):

> A sub-feature request that would fit here: (optionally?) keep both the transformed and original image/mask in the batch. […]

yeah, the idea is to keep batch.image and batch.gt_mask original outside the model. It is not working that way though :)

@jpcbertoldo (Contributor):

> yeah, the idea is to keep batch.image and batch.gt_mask original outside the model

exactly, makes sense : )

but it's also useful to be able to access the transformed one (e.g. when using augmentations)

> it is not working that way though :)

I didn't get this. Is it because it's not backwards-compatible?

@samet-akcay (Contributor, Author):

> I didn't get this. Is it because it's not backwards-compatible?

oh I meant, it is currently not working, I need to fix it :)


@ashwinvaidya17 (Collaborator) left a comment:

Thanks. I have a few minor comments.

Comment on lines 58 to 69
# Handle pre-processor
# True -> use default pre-processor
# False -> no pre-processor
# PreProcessor -> use the provided pre-processor
if isinstance(pre_processor, PreProcessor):
    self.pre_processor = pre_processor
elif isinstance(pre_processor, bool):
    self.pre_processor = self.configure_pre_processor()
else:
    msg = f"Invalid pre-processor type: {type(pre_processor)}"
    raise TypeError(msg)

Collaborator:

Minor comment, but can we move this to a separate method?

@samet-akcay (Contributor, Author) commented Oct 30, 2024:

Which one would you prefer: _init_pre_processor, _resolve_pre_processor, _handle_pre_processor, or _setup_pre_processor?

@@ -220,30 +250,12 @@ def input_size(self) -> tuple[int, int] | None:
The effective input size is the size of the input tensor after the transform has been applied. If the transform
is not set, or if the transform does not change the shape of the input tensor, this method will return None.
"""
-transform = self.transform or self.configure_transforms()
+transform = self.pre_processor.train_transform
Collaborator:

Should we add a check to ascertain whether train_transform is present? Models like VlmAD might not have train_transforms passed to them. I feel it should pick up the val or predict transform if the train transform is not available.
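
One possible shape for such a fallback, sketched as a small helper; the attribute names assume the current PreProcessor fields, and the helper name is hypothetical:

from torchvision.transforms.v2 import Transform

from anomalib.pre_processing import PreProcessor


def _resolve_input_transform(pre_processor: PreProcessor) -> Transform | None:
    """Pick the transform used to infer the effective input size.

    Prefers the train transform, falling back to the val/test transform for
    models (e.g. VLM-based ones) that are never trained.
    """
    return (
        pre_processor.train_transform
        or pre_processor.val_transform
        or pre_processor.test_transform
    )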

@@ -79,6 +93,10 @@ def _setup(self) -> None:
initialization.
"""

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure default callbacks for AnomalyModule."""
return [self.pre_processor]
Collaborator:

How can we ensure that the pre_processor callback is called before the other callbacks? For example, is the metrics callback dependent on pre-processing running first?

@samet-akcay (Contributor, Author):

In the base model, we will need to ensure the order of the callbacks, I guess. For child classes that inherit from this one, we could have something like this:

def configure_callbacks(self) -> Sequence[Callback]:
    """Configure callbacks with parent callbacks preserved."""
    # Get parent callbacks first
    parent_callbacks = super().configure_callbacks()
    
    # Add child-specific callbacks
    callbacks = [
        *parent_callbacks,      # Parent callbacks first
        MyCustomCallback(),     # Then child callbacks
        AnotherCallback(),
    ]
    return callbacks

if dataset_attr and hasattr(datamodule, dataset_attr):
    dataset = getattr(datamodule, dataset_attr)
    if hasattr(dataset, "transform"):
        dataset.transform = transform
Contributor:

Would there be a way to assign the transforms to the datamodule before the datasets are instantiated, instead of overwriting them here? That way the datasets would always have the right transform, which would be less prone to bugs.

I guess it is done this way because the setup callback hook gets called after the datamodule's setup hook, so by the time we set up the pre-processor the datasets have already been created. I just wonder if there is a different way to do it that does not involve overwriting the transforms.

@samet-akcay (Contributor, Author):

The setup logic changes completely in PR #2239. It might be better to revisit this when working on that PR instead of addressing it here.

Contributor:

Can we target the other PR to the feature branch then?

@samet-akcay (Contributor, Author) commented Nov 5, 2024:

We'll have to, but it requires quite a few changes, as it is out of date now. It falls within the 2.0 requirements (#2364).

Development

Successfully merging this pull request may close these issues:

  • 📋 [TASK] Integrate Pre-processing as AnomalibModule Attribute