-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathfunctional.py
1579 lines (1256 loc) · 66 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import math
import numbers
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from torch import Tensor
try:
import accimage
except ImportError:
accimage = None
from ..utils import _log_api_usage_once
from . import functional_pil as F_pil, functional_tensor as F_t
class InterpolationMode(Enum):
    """Interpolation methods accepted by the functional transforms.

    Each member's value is the lowercase method name: ``nearest``, ``nearest-exact``, ``bilinear``,
    ``bicubic``, ``box``, ``hamming`` and ``lanczos``.
    """

    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # The members below exist for PIL compatibility.
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"
# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    """Translate a legacy integer interpolation code into an :class:`InterpolationMode`.

    The accepted codes are 0-5; the table is the inverse of the non-NEAREST_EXACT entries of
    ``pil_modes_mapping``.

    Raises:
        KeyError: if ``i`` is not one of the known codes.
    """
    code_to_mode = {
        0: InterpolationMode.NEAREST,
        1: InterpolationMode.LANCZOS,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
    }
    return code_to_mode[i]
# Integer resampling code handed to the PIL backend for each InterpolationMode
# (looked up in `resize` and `perspective` before calling into functional_pil).
# NEAREST_EXACT has no separate code here and shares NEAREST's value (0).
pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.NEAREST_EXACT: 0,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}
# Module-level alias of the PIL-image predicate from functional_pil.
_is_pil_image = F_pil._is_pil_image
def get_dimensions(img: Tensor) -> List[int]:
    """Returns the dimensions of an image as [channels, height, width].

    Tensors are handled by ``functional_tensor.get_dimensions``; any other input is passed to
    ``functional_pil.get_dimensions``.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image dimensions ``[channels, height, width]``.
    """
    # NOTE(review): `img` is annotated as Tensor (presumably for TorchScript) although PIL images are accepted.
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_dimensions)
    if isinstance(img, torch.Tensor):
        return F_t.get_dimensions(img)
    return F_pil.get_dimensions(img)
def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].

    Note the ``[width, height]`` order, which differs from ``get_dimensions`` (``[channels, height, width]``).
    Tensors are handled by ``functional_tensor.get_image_size``; any other input is passed to
    ``functional_pil.get_image_size``.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size ``[width, height]``.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_size)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_size(img)
    return F_pil.get_image_size(img)
def get_image_num_channels(img: Tensor) -> int:
    """Returns the number of channels of an image.

    Tensors are handled by ``functional_tensor.get_image_num_channels``; any other input is passed to
    ``functional_pil.get_image_num_channels``.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_num_channels)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_num_channels(img)
    return F_pil.get_image_num_channels(img)
@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}
def to_tensor(pic) -> Tensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a channels-first tensor.

    This function does not support torchscript.
    See :class:`~torchvision.transforms.ToTensor` for more details.

    uint8 results are converted to ``torch.get_default_dtype()`` and divided by 255; tensors of any
    other dtype are returned without rescaling. accimage inputs are always returned as the default
    float dtype.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor. ndarrays must be 2-D (HW)
            or 3-D (HWC).

    Returns:
        Tensor: Converted image in CHW layout.

    Raises:
        TypeError: if ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: if an ndarray ``pic`` is not 2- or 3-dimensional.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    float_dtype = torch.get_default_dtype()

    def _scale_if_byte(chw: Tensor) -> Tensor:
        # Backward compatibility: uint8 data is mapped to [0, 1] floats; other dtypes pass through.
        if isinstance(chw, torch.ByteTensor):
            return chw.to(dtype=float_dtype).div(255)
        return chw

    if isinstance(pic, np.ndarray):
        # Give grayscale HW arrays a trailing channel axis, then move it to the front.
        hwc_array = pic[:, :, None] if pic.ndim == 2 else pic
        return _scale_if_byte(torch.from_numpy(hwc_array.transpose((2, 0, 1))).contiguous())

    if accimage is not None and isinstance(pic, accimage.Image):
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(buffer)
        return torch.from_numpy(buffer).to(dtype=float_dtype)

    # PIL path: choose a numpy dtype from the image mode, falling back to uint8.
    np_dtype = {"I": np.int32, "I;16": np.int16, "F": np.float32}.get(pic.mode, np.uint8)
    hwc = torch.from_numpy(np.array(pic, np_dtype, copy=True))
    if pic.mode == "1":
        hwc = 255 * hwc
    width, height = pic.size
    hwc = hwc.view(height, width, len(pic.getbands()))
    # HWC -> CHW
    return _scale_if_byte(hwc.permute((2, 0, 1)).contiguous())
def pil_to_tensor(pic: Any) -> Tensor:
    """Convert a ``PIL Image`` to a CHW tensor without changing its dtype or value range.

    This function does not support torchscript.
    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: if ``pic`` is not a PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pil_to_tensor)
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        chw_buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(chw_buffer)
        return torch.as_tensor(chw_buffer)

    # PIL reports size as (width, height); the copied array is laid out HWC.
    width, height = pic.size
    hwc = torch.as_tensor(np.array(pic, copy=True)).view(height, width, len(pic.getbands()))
    # HWC -> CHW (a permuted view, not made contiguous)
    return hwc.permute((2, 0, 1))
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    This function does not support PIL Image.

    The actual conversion is delegated to ``functional_tensor.convert_image_dtype``.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        TypeError: When ``image`` is not a ``torch.Tensor``.
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")
    return F_t.convert_image_dtype(image, dtype)
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Tensors are interpreted as CHW and ndarrays as HWC; a 2-D input gets a single channel added.
    Floating-point tensors are multiplied by 255 and cast to uint8 unless ``mode == "F"``.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: if ``pic`` is neither a Tensor nor an ndarray, or no mode can be inferred for its dtype.
        ValueError: if ``pic`` is not 2/3 dimensional, has more than 4 channels, or ``mode`` does not fit
            the number of channels / dtype.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)

    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

        # check number of channels
        if pic.shape[-3] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

        # check number of channels
        if pic.shape[-1] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic
    if isinstance(pic, torch.Tensor):
        # NOTE(review): presumably float tensors are expected in [0, 1]; other values wrap when cast to uint8.
        if pic.is_floating_point() and mode != "F":
            pic = pic.mul(255).byte()
        # CHW tensor -> HWC ndarray
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        # Fixed: this message was missing its f-prefix and printed the placeholder literally.
        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")

    if npimg.shape[2] == 1:
        # Single channel: the dtype alone determines the PIL mode.
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            # Fixed: report the array's dtype rather than the `np.dtype` class.
            raise ValueError(f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}")
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    The computation is delegated to ``functional_tensor.normalize``.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: if ``tensor`` is not a ``torch.Tensor``.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
def _compute_resized_output_size(
image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if len(size) == 1: # specified size only for the smallest edge
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
return [new_h, new_w]
def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    If the computed output size equals the current image size, ``img`` itself is returned unchanged.

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help making the output for PIL images and tensors closer.

    Returns:
        PIL Image or Tensor: Resized image.

    Raises:
        TypeError: if ``interpolation`` is neither an :class:`InterpolationMode` nor a legacy int.
        ValueError: if ``size`` is a sequence whose length is not 1 or 2, if ``max_size`` is combined with a
            2-element ``size``, or (via ``_compute_resized_output_size``) if ``max_size`` is not larger than the
            requested smaller-edge size.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
            "Please use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    if isinstance(size, int):
        size = [size]
    output_size = _compute_resized_output_size((image_height, image_width), size, max_size)

    # `output_size` is a list, so compare against a list: a tuple never compares equal to a list in
    # Python, which previously made this early exit unreachable.
    if [image_height, image_width] == output_size:
        return img

    if not isinstance(img, torch.Tensor):
        if antialias is not None and not antialias:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Tensors are padded by ``functional_tensor.pad``; any other input is passed to ``functional_pil.pad``.

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)

    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Tensors are cropped by ``functional_tensor.crop``; any other input is passed to ``functional_pil.crop``.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)
    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crop the centre of an image.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    When the requested crop is larger than the image along an edge, that edge is first zero-padded
    (the odd extra pixel going to the right/bottom) and the centre is then cut out.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. An int, or a sequence holding a
            single int, is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)

    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, img_h, img_w = get_dimensions(img)
    target_h, target_w = output_size

    if target_w > img_w or target_h > img_h:
        # Amount by which the crop overshoots each edge (0 when it already fits).
        extra_w = target_w - img_w if target_w > img_w else 0
        extra_h = target_h - img_h if target_h > img_h else 0
        padding_ltrb = [extra_w // 2, extra_h // 2, (extra_w + 1) // 2, (extra_h + 1) // 2]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, img_h, img_w = get_dimensions(img)
        if target_w == img_w and target_h == img_h:
            # Padding alone produced exactly the requested size.
            return img

    offset_top = int(round((img_h - target_h) / 2.0))
    offset_left = int(round((img_w - target_w) / 2.0))
    return crop(img, offset_top, offset_left, target_h, target_w)
def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = None,
) -> Tensor:
    """Cut a box out of an image and resize it; equivalent to ``resize(crop(img, ...), size, ...)``.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        antialias (bool, optional): antialias flag, forwarded to ``resize``. If ``img`` is PIL Image, the flag is
            ignored and anti-alias is always used. If ``img`` is Tensor, the flag is False by default and can be set
            to True for ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help making the output for PIL images and tensors closer.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    cropped = crop(img, top, left, height, width)
    return resize(cropped, size, interpolation, antialias=antialias)
def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Tensors are flipped by ``functional_tensor.hflip``; any other input is passed to ``functional_pil.hflip``.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Horizontally flipped image.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)
def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
Args:
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
Returns:
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
"""
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
output: List[float] = res.tolist()
return output
def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Apply a perspective warp that maps ``startpoints`` onto ``endpoints``.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.

    Raises:
        TypeError: if ``interpolation`` is neither an :class:`InterpolationMode` nor a legacy int.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    # The coefficients are solved before the interpolation argument is validated.
    warp_coeffs = _get_perspective_coeffs(startpoints, endpoints)

    # Legacy integer codes are still accepted, with a deprecation warning.
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
            "Please use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_resample = pil_modes_mapping[interpolation]
        return F_pil.perspective(img, warp_coeffs, interpolation=pil_resample, fill=fill)

    return F_t.perspective(img, warp_coeffs, interpolation=interpolation.value, fill=fill)
def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Tensors are flipped by ``functional_tensor.vflip``; any other input is passed to ``functional_pil.vflip``.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Vertically flipped image.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)
def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Cut the four corner crops and the centre crop out of an image.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. An int ``s`` or a one-element
            sequence ``[s]`` yields a square ``(s, s)`` crop; otherwise ``(h, w)`` is expected.

    Returns:
        tuple: ``(tl, tr, bl, br, center)`` — the top-left, top-right, bottom-left, bottom-right
        and center crops, in that order.

    Raises:
        ValueError: if ``size`` does not resolve to exactly two values, or the crop exceeds the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)

    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    _, img_h, img_w = get_dimensions(img)
    box_h, box_w = size
    if box_w > img_w or box_h > img_h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (img_h, img_w)))

    # Top offset of the bottom corners / left offset of the right corners.
    bottom = img_h - box_h
    right = img_w - box_w

    top_left = crop(img, 0, 0, box_h, box_w)
    top_right = crop(img, 0, right, box_h, box_w)
    bottom_left = crop(img, bottom, 0, box_h, box_w)
    bottom_right = crop(img, bottom, right, box_h, box_w)
    middle = center_crop(img, [box_h, box_w])

    return top_left, top_right, bottom_left, bottom_right, middle
def ten_crop(
    img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Generate ten cropped images from the given image.

    Crop the given image into four corners and the central crop, plus the
    same five crops taken from a flipped copy of the image (horizontal
    flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and
        center crop and same for the flipped image.

    Raises:
        ValueError: If ``size`` does not resolve to exactly two dimensions.
            ``five_crop`` may additionally raise if the crop exceeds the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)
    # Validate/normalize size up front so a bad size fails before any cropping.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])
    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)
    if vertical_flip:
        flipped = vflip(img)
    else:
        flipped = hflip(img)
    second_five = five_crop(flipped, size)
    # Two 5-tuples concatenate into a 10-tuple, which is what the return
    # annotation declares (it previously claimed List[Tensor]).
    return first_five + second_five
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale the brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to adjust. A Tensor input must have
            shape [..., 1 or 3, H, W], where ... stands for any number of
            leading dimensions.
        brightness_factor (float): Non-negative brightness multiplier. 0 yields
            a black image, 1 returns the original image, and 2 increases the
            brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    if isinstance(img, torch.Tensor):
        return F_t.adjust_brightness(img, brightness_factor)
    else:
        return F_pil.adjust_brightness(img, brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Scale the contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to adjust. A Tensor input must have
            shape [..., 1 or 3, H, W], where ... stands for any number of
            leading dimensions.
        contrast_factor (float): Non-negative contrast multiplier. 0 yields a
            solid gray image, 1 returns the original image, and 2 increases
            the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    if isinstance(img, torch.Tensor):
        return F_t.adjust_contrast(img, contrast_factor)
    else:
        return F_pil.adjust_contrast(img, contrast_factor)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Scale the color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to adjust. A Tensor input must have
            shape [..., 1 or 3, H, W], where ... stands for any number of
            leading dimensions.
        saturation_factor (float): Saturation multiplier. 0 yields a black and
            white image, 1 returns the original image, and 2 enhances the
            saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    if isinstance(img, torch.Tensor):
        return F_t.adjust_saturation(img, saturation_factor)
    else:
        return F_pil.adjust_saturation(img, saturation_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Shift the hue of an image.

    The image is converted to HSV, the hue channel (H) is shifted cyclically,
    and the result is converted back to the original image mode.
    ``hue_factor`` is the amount of that shift and must lie in the interval
    ``[-0.5, 0.5]``.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to adjust. A Tensor input must have
            shape [..., 1 or 3, H, W], where ... stands for any number of
            leading dimensions. For a PIL Image, modes "1", "I", "F" and modes
            with transparency (alpha channel) are not supported.
            Note: the pixel values of the input image have to be non-negative
            for the conversion to HSV space. This function therefore does not
            work on images normalized to an interval containing negative
            values, or after an interpolation that produces negative values.
        hue_factor (float): Amount by which to shift the hue channel, in
            [-0.5, 0.5]. 0.5 and -0.5 give a complete reversal of the hue
            channel in HSV space in the positive and negative direction
            respectively, so both produce an image with complementary colors.
            0 means no shift and returns the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    if isinstance(img, torch.Tensor):
        return F_t.adjust_hue(img, hue_factor)
    else:
        return F_pil.adjust_hue(img, hue_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Apply gamma correction (a power-law transform) to an image.

    Intensities in RGB mode are remapped according to

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): Image to adjust. A Tensor input must have
            shape [..., 1 or 3, H, W], where ... stands for any number of
            leading dimensions. For a PIL Image, modes with transparency
            (alpha channel) are not supported.
        gamma (float): Non-negative real number, the :math:`\gamma` of the
            equation above. A gamma larger than 1 makes the shadows darker,
            while a gamma smaller than 1 makes dark regions lighter.
        gain (float): Constant multiplier applied to the result.

    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    if isinstance(img, torch.Tensor):
        return F_t.adjust_gamma(img, gamma, gain)
    else:
        return F_pil.adjust_gamma(img, gamma, gain)
def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]