Skip to content

Commit

Permalink
move Datapoint out of public namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Dec 5, 2022
1 parent 5dc222b commit d675ff4
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 55 deletions.
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_feature_signatures_consistency(self, info):
try:
feature_method = getattr(datapoints.Datapoint, info.id)
feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datapoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import Datapoint, FillType, FillTypeJIT, InputType, InputTypeJIT
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel
from ._mask import Mask
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union

from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
Mapper,
UnBatcher,
)
from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label, Mask
from torchvision.prototype.datapoints import BoundingBox, Label, Mask
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/cub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/sbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datapoints import Datapoint
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
Expand Down
66 changes: 44 additions & 22 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to


def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
Expand Down Expand Up @@ -77,9 +79,11 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to


def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
Expand Down Expand Up @@ -116,9 +120,11 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.


def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
Expand Down Expand Up @@ -189,9 +195,11 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc


def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
Expand Down Expand Up @@ -301,9 +309,11 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:


def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
Expand Down Expand Up @@ -341,9 +351,11 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to


def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
Expand Down Expand Up @@ -375,9 +387,11 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:


def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.posterize(bits=bits)
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
Expand All @@ -403,9 +417,11 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:


def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.solarize(threshold=threshold)
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
Expand Down Expand Up @@ -453,9 +469,11 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:


def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.autocontrast()
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
Expand Down Expand Up @@ -543,9 +561,11 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:


def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return equalize_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.equalize()
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
Expand Down Expand Up @@ -574,9 +594,11 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:


def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.invert()
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
Expand Down
Loading

0 comments on commit d675ff4

Please sign in to comment.