Skip to content

Commit

Permalink
extract make_* functions out of make_*_loader
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 3, 2023
1 parent ce7075c commit 1450781
Showing 1 changed file with 37 additions and 9 deletions.
46 changes: 37 additions & 9 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,34 @@ def get_num_channels(color_space):
return num_channels


def make_image(
spatial_size,
*,
color_space="RGB",
batch_dims=(),
dtype=torch.float32,
device="cpu",
constant_alpha=True,
memory_format=torch.contiguous_format,
):
spatial_size = _parse_spatial_size(spatial_size)
num_channels = get_num_channels(color_space)
max_value = get_max_value(dtype)

data = torch.testing.make_tensor(
(*batch_dims, num_channels, *spatial_size),
low=0,
high=max_value,
dtype=dtype,
device=device,
memory_format=memory_format,
)
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
data[..., -1, :, :] = max_value

return datapoints.Image(data)


def make_image_loader(
size="random",
*,
Expand All @@ -505,20 +533,20 @@ def make_image_loader(
num_channels = get_num_channels(color_space)

def fn(shape, dtype, device, memory_format):
max_value = get_max_value(dtype)
data = torch.testing.make_tensor(
shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
*batch_dims, _, spatial_size = shape
return make_image(
spatial_size,
color_space=color_space,
batch_dims=batch_dims,
dtype=dtype,
device=device,
constant_alpha=constant_alpha,
memory_format=memory_format,
)
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
data[..., -1, :, :] = max_value
return datapoints.Image(data)

return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)


make_image = from_loader(make_image_loader)


def make_image_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
Expand Down

0 comments on commit 1450781

Please sign in to comment.