Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add custom transformations #309

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions keras_preprocessing/image/image_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class ImageDataGenerator(object):
(strictly between 0 and 1).
interpolation_order: int, order to use for
the spline interpolation. Higher is slower.
list_of_custom_transformations: list, list of transformation functions
dtype: Dtype to use for the generated arrays.

# Examples
Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(self,
data_format='channels_last',
validation_split=0.0,
interpolation_order=1,
list_of_custom_transformations=None,
dtype='float32'):

self.featurewise_center = featurewise_center
Expand All @@ -296,6 +298,7 @@ def __init__(self,
self.preprocessing_function = preprocessing_function
self.dtype = dtype
self.interpolation_order = interpolation_order
self.list_of_custom_transformations = list_of_custom_transformations

if data_format not in {'channels_last', 'channels_first'}:
raise ValueError(
Expand Down Expand Up @@ -892,6 +895,11 @@ def apply_transform(self, x, transform_parameters):
if transform_parameters.get('brightness') is not None:
x = apply_brightness_shift(x, transform_parameters['brightness'])

# Custom Transformations
if self.list_of_custom_transformations:
for transformation in self.list_of_custom_transformations:
x = transformation(x)

return x

def random_transform(self, x, seed=None):
Expand Down