Add normalize to image transforms module #19544

Merged on Oct 17, 2022 (219 commits)

Commits
a94c537
Adapt FE methods to transforms library
amyeroberts Jul 27, 2022
932f291
Mixin for saving the image processor
amyeroberts Jul 27, 2022
54aed8b
Base processor skeleton
amyeroberts Jul 27, 2022
ba55c89
BatchFeature for packaging image processor outputs
amyeroberts Jul 27, 2022
4b430d4
Initial image processor for GLPN
amyeroberts Jul 27, 2022
b1c8b59
REmove accidental import
amyeroberts Jul 27, 2022
daf069a
Fixup and docs
amyeroberts Jul 28, 2022
95b4a6a
Mixin for saving the image processor
amyeroberts Jul 27, 2022
6f7ef56
Fixup and docs
amyeroberts Jul 28, 2022
b9ce4a0
Import BatchFeature from feature_extraction_utils
amyeroberts Jul 28, 2022
f02ae6a
Merge branch 'image-processor-mixin' of github.com:amyeroberts/transf…
amyeroberts Jul 28, 2022
6b678fb
Fixup and docs
amyeroberts Jul 28, 2022
db93437
Fixup and docs
amyeroberts Jul 28, 2022
bd890d5
Fixup and docs
amyeroberts Jul 28, 2022
4b27a34
Fixup and docs
amyeroberts Jul 28, 2022
ff0d49e
BatchFeature for packaging image processor outputs
amyeroberts Jul 27, 2022
2c2fa9a
Import BatchFeature from feature_extraction_utils
amyeroberts Jul 28, 2022
b9f7837
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Jul 28, 2022
346270d
Resolve conflicts
amyeroberts Jul 28, 2022
7faf2e6
Import BatchFeature from feature_extraction_utils
amyeroberts Jul 28, 2022
ccc15fb
Fixup and docs
amyeroberts Jul 28, 2022
c8f8eb6
Fixup and docs
amyeroberts Jul 28, 2022
90093f4
BatchFeature for packaging image processor outputs
amyeroberts Jul 27, 2022
d89c051
Import BatchFeature from feature_extraction_utils
amyeroberts Jul 28, 2022
9bc9157
Fixup and docs
amyeroberts Jul 28, 2022
6ec382a
Mixin for saving the image processor
amyeroberts Jul 27, 2022
56ee6ad
Fixup and docs
amyeroberts Jul 28, 2022
38ebb50
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Jul 28, 2022
6b88d5f
Add rescale back and remove ImageType
amyeroberts Jul 28, 2022
67077f1
fix import mistake
amyeroberts Jul 28, 2022
82712c7
Fix enum var reference
amyeroberts Jul 28, 2022
71d666d
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Jul 28, 2022
fb6438c
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Jul 28, 2022
ffe71b6
Merge branch 'base-image-processor-class' into image-batch-feature
amyeroberts Jul 28, 2022
cc480e8
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Jul 28, 2022
b997a98
Can transform and specify image data format
amyeroberts Jul 28, 2022
9106443
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Jul 28, 2022
1b3cf65
Remove redundant function
amyeroberts Jul 28, 2022
2860460
Update reference
amyeroberts Jul 28, 2022
3e1077b
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Jul 28, 2022
4264d1a
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Jul 28, 2022
fb5dcd6
Merge in branch and remove conflicts
amyeroberts Jul 28, 2022
43f561d
Add in rescaling
amyeroberts Jul 29, 2022
60c56e5
Data format flag for rescale
amyeroberts Jul 29, 2022
9294dbc
Fix typo
amyeroberts Jul 29, 2022
654cf93
Fix dimension check
amyeroberts Jul 29, 2022
1360732
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Jul 29, 2022
936de65
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Jul 29, 2022
627c048
Merge branch 'base-image-processor-class' into image-batch-feature
amyeroberts Jul 29, 2022
1b64c80
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Jul 29, 2022
88b82e9
Fixes to make IP and FE outputs match
amyeroberts Jul 29, 2022
3ea27aa
Add tests for transforms
amyeroberts Jul 29, 2022
84fdd07
Add test for utils
amyeroberts Jul 29, 2022
10d56b1
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Jul 29, 2022
392e980
Update some docstrings
amyeroberts Aug 2, 2022
2117b94
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Aug 2, 2022
68de952
Resole merge conflicts
amyeroberts Aug 2, 2022
5208680
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Aug 2, 2022
a28ac88
Make sure in channels last before converting to PIL
amyeroberts Aug 2, 2022
2ead9e5
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Aug 2, 2022
9514d54
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Aug 2, 2022
8f63b76
Merge branch 'base-image-processor-class' into image-batch-feature
amyeroberts Aug 2, 2022
46a9c74
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Aug 2, 2022
082e4ff
Remove default to numpy batching
amyeroberts Aug 2, 2022
bf73358
Fix up
amyeroberts Aug 3, 2022
34b6b2f
Add docstring and model_input_types
amyeroberts Aug 4, 2022
7150293
Use feature processor config from hub
amyeroberts Aug 4, 2022
8678c13
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Aug 4, 2022
937884c
Merge branch 'base-image-processor-class' into image-batch-feature
amyeroberts Aug 4, 2022
a1b681a
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Aug 4, 2022
b1db434
Alias GLPN feature extractor to image processor
amyeroberts Aug 4, 2022
f0c14ee
Alias feature extractor mixin
amyeroberts Aug 5, 2022
952c2a0
Resolve merge conflicts
amyeroberts Aug 5, 2022
2f0fa0b
Resolve merge conflicts
amyeroberts Aug 5, 2022
e6233cc
Resolve merge conflicts
amyeroberts Aug 5, 2022
ddc8cf9
Merge branch 'image-processor-glpn' into rename-fe-to-ip-glpn
amyeroberts Aug 5, 2022
5407de6
Merge in main
amyeroberts Aug 5, 2022
f1cf228
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Aug 5, 2022
bd0afd6
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Aug 5, 2022
a6f69bc
Merge branch 'base-image-processor-class' into image-batch-feature
amyeroberts Aug 5, 2022
a7af81f
Merge and resolve conflicts
amyeroberts Aug 5, 2022
ad58bd9
Merge branch 'image-processor-glpn' into rename-fe-to-ip-glpn
amyeroberts Aug 5, 2022
affb945
Add return_numpy=False flag for resize
amyeroberts Aug 7, 2022
5891dd8
Merge branch 'image-transforms-library' into image-processor-mixin
amyeroberts Aug 7, 2022
b66d0f6
Merge branch 'image-processor-mixin' into base-image-processor-class
amyeroberts Aug 7, 2022
8b73f89
Merge branch 'base-image-processor-class' into image-batch-feature
amyeroberts Aug 7, 2022
ae6030c
Merge branch 'image-batch-feature' into image-processor-glpn
amyeroberts Aug 7, 2022
78bdfb3
Merge branch 'image-processor-glpn' into rename-fe-to-ip-glpn
amyeroberts Aug 7, 2022
7a4d22a
Fix up
amyeroberts Aug 8, 2022
994e040
Fix up
amyeroberts Aug 8, 2022
42c23bd
Use different frameworks safely
amyeroberts Aug 8, 2022
05c65f6
Safely import PIL
amyeroberts Aug 8, 2022
feb9556
Call function checking if PIL available
amyeroberts Aug 8, 2022
a30b007
Only import if vision available
amyeroberts Aug 8, 2022
fd7b6c7
Address Sylvain PR comments
amyeroberts Aug 9, 2022
790c2c6
Apply suggestions from code review
amyeroberts Aug 10, 2022
2e929cf
Update src/transformers/image_transforms.py
amyeroberts Aug 12, 2022
ff04de3
Update src/transformers/models/glpn/feature_extraction_glpn.py
amyeroberts Aug 12, 2022
cb1dcd8
Merge pull request #25 from amyeroberts/image-processor-mixin
amyeroberts Aug 16, 2022
ae35873
Add in docstrings
amyeroberts Aug 17, 2022
4fff267
Merge pull request #23 from amyeroberts/image-processor-glpn
amyeroberts Aug 17, 2022
62c6e55
Merge pull request #26 from amyeroberts/image-batch-feature
amyeroberts Aug 17, 2022
271a09d
Merge pull request #24 from amyeroberts/rename-fe-to-ip-glpn
amyeroberts Aug 17, 2022
b7edea0
Fix TFSwinSelfAttention to have relative position index as non-traina…
harrydrippin Aug 5, 2022
8cacf30
Refactor `TFSwinLayer` to increase serving compatibility (#18352)
harrydrippin Aug 5, 2022
e385c5a
Add TF prefix to TF-Res test class (#18481)
ydshieh Aug 5, 2022
ed4b059
Remove py.typed (#18485)
sgugger Aug 5, 2022
553be89
Fix pipeline tests (#18487)
sgugger Aug 5, 2022
cfa16eb
Use new huggingface_hub tools for download models (#18438)
sgugger Aug 5, 2022
7472b39
Fix `test_dbmdz_english` by updating expected values (#18482)
ydshieh Aug 5, 2022
35a534a
Move cache folder to huggingface/hub for consistency with hf_hub (#18…
sgugger Aug 5, 2022
2c96675
Update some expected values in `quicktour.mdx` for `resampy 0.3.0` (#…
ydshieh Aug 5, 2022
fc87969
Forgot one new_ for cache migration
sgugger Aug 5, 2022
e8f5772
disable Onnx test for google/long-t5-tglobal-base (#18454)
ydshieh Aug 5, 2022
707c0ff
Typo reported by Joel Grus on TWTR (#18493)
julien-c Aug 5, 2022
0ff35a8
Just re-reading the whole doc every couple of months 😬 (#18489)
julien-c Aug 6, 2022
1d34656
`transformers-cli login` => `huggingface-cli login` (#18490)
julien-c Aug 6, 2022
a610155
Add seed setting to image classification example (#18519)
regisss Aug 8, 2022
2f493d5
[DX fix] Fixing QA pipeline streaming a dataset. (#18516)
Narsil Aug 8, 2022
80c33f8
Clean up hub (#18497)
sgugger Aug 8, 2022
c6e979f
update fsdp docs (#18521)
pacman100 Aug 8, 2022
2884397
Fix compatibility with 1.12 (#17925)
sgugger Aug 8, 2022
e9ba674
Remove debug statement
sgugger Aug 8, 2022
dcb1685
Specify en in doc-builder README example (#18526)
ankrgyl Aug 8, 2022
7072f66
New cache fixes: add safeguard before looking in folders (#18522)
sgugger Aug 8, 2022
c5e228e
unpin resampy (#18527)
ydshieh Aug 8, 2022
6952e9b
✨ update to use interlibrary links instead of Markdown (#18500)
stevhliu Aug 8, 2022
8a18ad9
Add example of multimodal usage to pipeline tutorial (#18498)
stevhliu Aug 8, 2022
e9c67f7
[VideoMAE] Add model to doc tests (#18523)
NielsRogge Aug 8, 2022
7aa5bfd
Update perf_train_gpu_one.mdx (#18532)
mishig25 Aug 8, 2022
5606dba
Update no_trainer.py scripts to include accelerate gradient accumulat…
Rasmusafj Aug 8, 2022
5b29a58
Add Spanish translation of converting_tensorflow_models.mdx (#18512)
donelianc Aug 8, 2022
defa14c
Spanish translation of summarization.mdx (#15947) (#18477)
AguilaCudicio Aug 8, 2022
87271d1
Let's not cast them all (#18471)
younesbelkada Aug 8, 2022
a9b2968
fix: data2vec-vision Onnx ready-made configuration. (#18427)
NikeNano Aug 9, 2022
24f688f
Add mt5 onnx config (#18394)
chainyo Aug 9, 2022
c465437
Minor update of `run_call_with_unpacked_inputs` (#18541)
ydshieh Aug 9, 2022
cafb76e
BART - Fix attention mask device issue on copied models (#18540)
younesbelkada Aug 9, 2022
a25b1b3
Adding a new `align_to_words` param to qa pipeline. (#18010)
Narsil Aug 9, 2022
fdd9c95
📝 update metric with evaluate (#18535)
stevhliu Aug 9, 2022
5e8a3d4
Restore _init_weights value in no_init_weights (#18504)
YouJiacheng Aug 9, 2022
ba98271
Clean up comment
sgugger Aug 9, 2022
3a70590
📝 update documentation build section (#18548)
stevhliu Aug 9, 2022
09f36ba
`bitsandbytes` - `Linear8bitLt` integration into `transformers` model…
younesbelkada Aug 10, 2022
ca3833e
TF: XLA-trainable DeBERTa v2 (#18546)
gante Aug 10, 2022
b84379c
Preserve hub-related kwargs in AutoModel.from_pretrained (#18545)
sgugger Aug 10, 2022
8d7065e
TF Examples Rewrite (#18451)
Rocketknight1 Aug 10, 2022
c9c5420
Use commit hash to look in cache instead of calling head (#18534)
sgugger Aug 10, 2022
5d39088
`pipeline` support for `device="mps"` (or any other string) (#18494)
julien-c Aug 10, 2022
0544879
Update philosophy to include other preprocessing classes (#18550)
stevhliu Aug 10, 2022
8b98733
Properly move cache when it is not in default path (#18563)
sgugger Aug 10, 2022
c2fc948
Adds CLIP to models exportable with ONNX (#18515)
unography Aug 10, 2022
fe29e4c
raise atol for MT5OnnxConfig (#18560)
ydshieh Aug 10, 2022
793d978
fix string (#18568)
mrwyattii Aug 10, 2022
8aea331
Segformer TF: fix output size in documentation (#18572)
joihn Aug 11, 2022
db07c44
Fix resizing bug in OWL-ViT (#18573)
alaradirik Aug 11, 2022
5a29d4f
Fix LayoutLMv3 documentation (#17932)
pocca2048 Aug 11, 2022
ad4215f
Skip broken tests
sgugger Aug 11, 2022
a272ed0
Change BartLearnedPositionalEmbedding's forward method signature to s…
donebydan Aug 11, 2022
6d8ab27
german docs translation (#18544)
flozi00 Aug 11, 2022
9d87c2d
Deberta V2: Fix critical trace warnings to allow ONNX export (#18272)
iiLaurens Aug 11, 2022
1c38f1a
[FX] _generate_dummy_input supports audio-classification models for l…
michaelbenayoun Aug 11, 2022
529ac2b
Fix docstrings with last version of hf-doc-builder styler (#18581)
sgugger Aug 11, 2022
8a8a9a1
Bump nbconvert from 6.0.1 to 6.3.0 in /examples/research_projects/lxm…
dependabot[bot] Aug 11, 2022
5a46799
Bump nbconvert in /examples/research_projects/visual_bert (#18566)
dependabot[bot] Aug 11, 2022
5d1df72
fix owlvit tests, update docstring examples (#18586)
alaradirik Aug 11, 2022
f03866f
Return the permuted hidden states if return_dict=True (#18578)
amyeroberts Aug 11, 2022
261f480
Load sharded pt to flax (#18419)
ArthurZucker Aug 12, 2022
ff90f49
Add type hints for ViLT models (#18577)
donelianc Aug 12, 2022
1e7062a
update doc for perf_train_cpu_many, add intel mpi introduction (#18576)
sywangyi Aug 12, 2022
c472b59
typos (#18594)
stas00 Aug 12, 2022
8cd549f
FSDP bug fix for `load_state_dict` (#18596)
pacman100 Aug 12, 2022
b0dea99
Add `TFAutoModelForSemanticSegmentation` to the main `__init__.py` (#…
ydshieh Aug 12, 2022
b93957a
Generate: validate `model_kwargs` (and catch typos in generate argume…
gante Aug 12, 2022
b881653
Supporting seq2seq models for `bitsandbytes` integration (#18579)
younesbelkada Aug 12, 2022
a9a0e18
Add Donut (#18488)
NielsRogge Aug 12, 2022
089ad23
Fix URLs (#18604)
NielsRogge Aug 12, 2022
f1590b2
Update BLOOM parameter counts (#18531)
Muennighoff Aug 12, 2022
b2fe78b
[doc] fix anchors (#18591)
stas00 Aug 12, 2022
c9d8c70
[fsmt] deal with -100 indices in decoder ids (#18592)
stas00 Aug 12, 2022
af92441
small change (#18584)
younesbelkada Aug 12, 2022
2287492
Flax Remat for LongT5 (#17994)
KMFODA Aug 14, 2022
c97b085
mac m1 `mps` integration (#18598)
pacman100 Aug 16, 2022
ea2c992
Change scheduled CIs to use torch 1.12.1 (#18644)
ydshieh Aug 16, 2022
771d6c0
Add checks for some workflow jobs (#18583)
ydshieh Aug 16, 2022
b53ef28
TF: Fix generation repetition penalty with XLA (#18648)
gante Aug 16, 2022
1769f66
Update longt5.mdx (#18634)
flozi00 Aug 16, 2022
b2dc2f3
Update run_translation_no_trainer.py (#18637)
zhoutang776 Aug 16, 2022
ab9d3b4
[bnb] Minor modifications (#18631)
younesbelkada Aug 16, 2022
a316ea3
Examples: add Bloom support for token classification (#18632)
stefan-it Aug 17, 2022
c6751ea
Fix Yolos ONNX export test (#18606)
ydshieh Aug 17, 2022
b7046bc
Fixup
amyeroberts Aug 17, 2022
f8a6b87
Fix up
amyeroberts Aug 17, 2022
b6fd4e3
Resolve conflicts
amyeroberts Aug 17, 2022
a37bce3
Move PIL default arguments inside function for safe imports
amyeroberts Aug 17, 2022
6ec9dbb
Add image utils to toctree
amyeroberts Aug 17, 2022
7693600
Update `rescale` method to reflect changes in #18677
amyeroberts Aug 18, 2022
464a4f2
Update docs/source/en/internal/image_processing_utils.mdx
amyeroberts Aug 23, 2022
713e958
Address Niels PR comments
amyeroberts Aug 23, 2022
4e60a76
Add normalize method to transforms library
amyeroberts Aug 24, 2022
6ec76ff
Apply suggestions from code review - remove defaults to None
amyeroberts Sep 2, 2022
adc0f9d
Merge branch 'main' into image-transforms-library
amyeroberts Sep 28, 2022
df81b6a
Merge branch 'main' into image-transforms-library
amyeroberts Oct 12, 2022
38a2427
Merge branch 'image-transforms-library' into image-transforms-add-nor…
amyeroberts Oct 12, 2022
48a07a1
Fix docstrings and revert to PIL.Image.XXX resampling
amyeroberts Oct 12, 2022
8785229
Some more docstrings and PIL.Image tidy up
amyeroberts Oct 12, 2022
1b2ce56
Merge branch 'image-transforms-library' into image-transforms-add-nor…
amyeroberts Oct 12, 2022
d44fe63
Reorganise arguments so flags by modifiers
amyeroberts Oct 12, 2022
16991f0
Merge branch 'image-transforms-library' into image-transforms-add-nor…
amyeroberts Oct 12, 2022
83330ef
Few last docstring fixes
amyeroberts Oct 12, 2022
eba05e2
Merge branch 'image-transforms-library' into image-transforms-add-nor…
amyeroberts Oct 12, 2022
bad7174
Merge branch 'main' into image-transforms-library
amyeroberts Oct 12, 2022
3191025
Merge branch 'image-transforms-library' into image-transforms-add-nor…
amyeroberts Oct 12, 2022
3dd6e02
Resolve conflicts with main
amyeroberts Oct 12, 2022
292f786
Add normalize to docs
amyeroberts Oct 12, 2022
7ea393c
Accept PIL.Image inputs with deprecation warning
amyeroberts Oct 12, 2022
0f99661
Update src/transformers/image_transforms.py
amyeroberts Oct 17, 2022
799829b
Update warning to include version
amyeroberts Oct 17, 2022
d396382
Trigger CI - hash clash on doc build
amyeroberts Oct 17, 2022
2 changes: 2 additions & 0 deletions docs/source/en/internal/image_processing_utils.mdx
@@ -19,6 +19,8 @@ Most of those are only useful if you are studying the code of the image processo

## Image Transformations

[[autodoc]] image_transforms.normalize

[[autodoc]] image_transforms.rescale

[[autodoc]] image_transforms.resize
61 changes: 60 additions & 1 deletion src/transformers/image_transforms.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+import warnings
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union

import numpy as np

@@ -25,11 +26,13 @@

from .image_utils import (
    ChannelDimension,
    get_channel_dimension_axis,
    get_image_size,
    infer_channel_dimension_format,
    is_jax_tensor,
    is_tf_tensor,
    is_torch_tensor,
    to_numpy_array,
)


@@ -257,3 +260,59 @@ def resize(
resized_image = np.array(resized_image)
resized_image = to_channel_dimension_format(resized_image, data_format)
return resized_image


def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
    data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
    """
    Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.

    image = (image - mean) / std

    Args:
        image (`np.ndarray`):
            The image to normalize.
        mean (`float` or `Iterable[float]`):
            The mean to use for normalization.
        std (`float` or `Iterable[float]`):
            The standard deviation to use for normalization.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If `None`, will use the inferred format from the input.
    """
    if isinstance(image, PIL.Image.Image):
        warnings.warn(
            "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
            FutureWarning,
        )
        # Convert PIL image to numpy array with the same logic as in the previous feature extractor normalize -
        # casting to numpy array and dividing by 255.
        image = to_numpy_array(image)
        image = rescale(image, scale=1 / 255)

    input_data_format = infer_channel_dimension_format(image)
    channel_axis = get_channel_dimension_axis(image)
    num_channels = image.shape[channel_axis]

    if isinstance(mean, Iterable):
        if len(mean) != num_channels:
            raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
    else:
        mean = [mean] * num_channels

    if isinstance(std, Iterable):
        if len(std) != num_channels:
            raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
    else:
        std = [std] * num_channels

    if input_data_format == ChannelDimension.LAST:
        image = (image - mean) / std
    else:
        image = ((image.T - mean) / std).T

    image = to_channel_dimension_format(image, data_format) if data_format is not None else image
    return image
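
For readers following the broadcasting logic in `normalize`, here is a minimal standalone sketch of per-channel normalization in plain NumPy. It is not the code from this PR: `normalize_sketch` and `_per_channel` are hypothetical names, and instead of the transpose trick used in the diff it reshapes `mean` and `std` so they broadcast along an explicitly given channel axis. The final assertion checks that, for a single channels-first image, both approaches agree.

```python
import numpy as np


def _per_channel(value, num_channels, name):
    # Hypothetical helper: expand a scalar to one value per channel,
    # or check that a sequence has exactly one value per channel.
    values = np.asarray(value, dtype=np.float64)
    if values.ndim == 0:
        return np.full(num_channels, float(values))
    if values.shape != (num_channels,):
        raise ValueError(f"{name} must have {num_channels} elements, got {values.size}")
    return values


def normalize_sketch(image, mean, std, channel_axis=-1):
    # Illustrative only: computes (image - mean) / std per channel.
    image = np.asarray(image, dtype=np.float64)
    num_channels = image.shape[channel_axis]
    mean = _per_channel(mean, num_channels, "mean")
    std = _per_channel(std, num_channels, "std")
    # Shape of all ones except along the channel axis, so the statistics
    # broadcast against every other dimension.
    shape = [1] * image.ndim
    shape[channel_axis] = num_channels
    return (image - mean.reshape(shape)) / std.reshape(shape)


rng = np.random.default_rng(0)
hwc = rng.random((4, 5, 3))      # channels last
chw = hwc.transpose(2, 0, 1)     # same image, channels first
mean, std = (0.5, 0.6, 0.7), (0.1, 0.2, 0.3)

out_last = normalize_sketch(hwc, mean, std, channel_axis=-1)
out_first = normalize_sketch(chw, mean, std, channel_axis=0)

# The transpose trick from the diff gives the same result for one 3D image.
assert np.allclose(out_first, ((chw.T - mean) / std).T)
```

Reshaping the statistics is only one possible design. It has the side benefit of working unchanged when a leading batch dimension is present, as long as the caller passes the right `channel_axis`.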
19 changes: 19 additions & 0 deletions src/transformers/image_utils.py
@@ -112,6 +112,25 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
raise ValueError("Unable to infer channel dimension format")


def get_channel_dimension_axis(image: np.ndarray) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.

    Returns:
        The channel dimension axis of the image.
    """
    channel_dim = infer_channel_dimension_format(image)
    if channel_dim == ChannelDimension.FIRST:
        return image.ndim - 3
    elif channel_dim == ChannelDimension.LAST:
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {channel_dim}")


def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.
23 changes: 23 additions & 0 deletions tests/test_image_transforms.py
@@ -36,6 +36,7 @@

from transformers.image_transforms import (
    get_resize_output_image_size,
    normalize,
    resize,
    to_channel_dimension_format,
    to_pil_image,
@@ -172,3 +173,25 @@ def test_resize(self):
        self.assertIsInstance(resized_image, PIL.Image.Image)
        # PIL size is in (width, height) order
        self.assertEqual(resized_image.size, (40, 30))

    def test_normalize(self):
        image = np.random.randint(0, 256, (224, 224, 3)) / 255

        # Number of mean values != number of channels
        with self.assertRaises(ValueError):
            normalize(image, mean=(0.5, 0.6), std=1)

        # Number of std values != number of channels
        with self.assertRaises(ValueError):
            normalize(image, mean=1, std=(0.5, 0.6))

        # Test result is correct - output data format is channels_first and normalization
        # correctly computed
        mean = (0.5, 0.6, 0.7)
        std = (0.1, 0.2, 0.3)
        expected_image = ((image - mean) / std).transpose((2, 0, 1))

        normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first")
        self.assertIsInstance(normalized_image, np.ndarray)
        self.assertEqual(normalized_image.shape, (3, 224, 224))
        self.assertTrue(np.allclose(normalized_image, expected_image))
25 changes: 24 additions & 1 deletion tests/utils/test_image_utils.py
@@ -20,7 +20,7 @@
import pytest

from transformers import is_torch_available, is_vision_available
-from transformers.image_utils import ChannelDimension
+from transformers.image_utils import ChannelDimension, get_channel_dimension_axis
from transformers.testing_utils import require_torch, require_vision


@@ -535,3 +535,26 @@ def test_infer_channel_dimension(self):
        image = np.random.randint(0, 256, (1, 3, 4, 5))
        inferred_dim = infer_channel_dimension_format(image)
        self.assertEqual(inferred_dim, ChannelDimension.FIRST)

    def test_get_channel_dimension_axis(self):
        # Test we correctly identify the channel dimension
        image = np.random.randint(0, 256, (3, 4, 5))
        inferred_axis = get_channel_dimension_axis(image)
        self.assertEqual(inferred_axis, 0)

        image = np.random.randint(0, 256, (1, 4, 5))
        inferred_axis = get_channel_dimension_axis(image)
        self.assertEqual(inferred_axis, 0)

        image = np.random.randint(0, 256, (4, 5, 3))
        inferred_axis = get_channel_dimension_axis(image)
        self.assertEqual(inferred_axis, 2)

        image = np.random.randint(0, 256, (4, 5, 1))
        inferred_axis = get_channel_dimension_axis(image)
        self.assertEqual(inferred_axis, 2)

        # We can take a batched array of images and find the dimension
        image = np.random.randint(0, 256, (1, 3, 4, 5))
        inferred_axis = get_channel_dimension_axis(image)
        self.assertEqual(inferred_axis, 1)
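
The axis lookup exercised by `test_get_channel_dimension_axis` can be sketched without importing `transformers`. The function below is a hypothetical stand-in, not the library code. It assumes the inference rule is "a 3D or 4D array whose candidate channel axis has size 1 or 3, with channels-first checked before channels-last", which matches the cases in the test above but is an assumption about `infer_channel_dimension_format`, whose body is not shown in this diff.

```python
import numpy as np


def channel_axis_sketch(image: np.ndarray) -> int:
    # Hypothetical stand-in: returns ndim - 3 for channels-first
    # and ndim - 1 for channels-last, as in the diff.
    if image.ndim not in (3, 4):
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
    first_axis = image.ndim - 3  # 0 for (C, H, W), 1 for (N, C, H, W)
    last_axis = image.ndim - 1   # 2 for (H, W, C), 3 for (N, H, W, C)
    # Assumed rule: a channel axis has 1 (grayscale) or 3 (RGB) entries,
    # and channels-first wins when both axes qualify.
    if image.shape[first_axis] in (1, 3):
        return first_axis
    if image.shape[last_axis] in (1, 3):
        return last_axis
    raise ValueError("Unable to infer channel axis")


# Same shapes as the test cases in the diff.
cases = {
    (3, 4, 5): 0,
    (1, 4, 5): 0,
    (4, 5, 3): 2,
    (4, 5, 1): 2,
    (1, 3, 4, 5): 1,
}
for shape, expected_axis in cases.items():
    assert channel_axis_sketch(np.zeros(shape)) == expected_axis
```

Under this assumed rule an ambiguous shape such as `(3, 4, 3)` resolves to axis 0, so callers with unusual image sizes may prefer to pass the data format explicitly where an API allows it.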