extend support of posterize to all integer and floating dtypes #6847

Merged: 5 commits merged into pytorch:main from posterize-dtypes on Oct 28, 2022

Conversation

@pmeier (Collaborator) commented Oct 27, 2022

This PR addresses #6840 for posterize. I'm going to over-explain how the kernel works for uint8 as a base for the explanation of the other dtypes.

posterize for integers changes the bit depth of an image. For example, setting bits=3 reduces the image to a bit depth of 3, i.e. 2**3 == 8 different values. This happens by keeping only the first bits bits of each value. Since uint8 arithmetic works modulo 2**8, the mask

mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)

is equivalent to

>>> bits=3
>>> 256 - (2 ** (8 - bits))
224
>>> torch.tensor(-(2 ** (8 - bits)), dtype=torch.uint8)
tensor(224, dtype=torch.uint8)

or in binary

>>> bin(224)
'0b11100000'

By &'ing this mask with the image, we achieve the desired result:

>>> image = torch.arange(0, 256, dtype=torch.uint8)
>>> posterized_image = image & 224
>>> torch.unique(posterized_image)
tensor([  0,  32,  64,  96, 128, 160, 192, 224], dtype=torch.uint8)

For the other integer dtypes, we can simply use the same scheme. We already have the helper function

def _num_value_bits(dtype: torch.dtype) -> int:

which gives us the bit depths. Plus, since we treat the signed integers as unsigned ones by restricting their minimum to 0, we also don't need to care about the leading bit that encodes the sign.
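
As a rough sketch of this scheme (illustrative only, not the exact kernel in this PR), the integer path could look like the following; the hard-coded bit depths and the function name are assumptions standing in for _num_value_bits and the real kernel:

import torch

# Assumed value-bit counts standing in for _num_value_bits (illustrative only):
# signed dtypes lose one bit to the sign, which we can ignore since inputs are >= 0.
NUM_VALUE_BITS = {torch.uint8: 8, torch.int8: 7, torch.int16: 15, torch.int32: 31, torch.int64: 63}

def posterize_int_sketch(image: torch.Tensor, bits: int) -> torch.Tensor:
    num_value_bits = NUM_VALUE_BITS[image.dtype]
    # Keep only the leading `bits` value bits and zero out the rest.
    mask = -int(2 ** (num_value_bits - bits))  # JIT-friendly for: ~(2 ** (num_value_bits - bits) - 1)
    return image & mask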

posterize is thus a (re-)quantization transformation. Ignoring finite precision, floating point images are not quantized to begin with. Quantizing them can easily be achieved in four steps (a short demonstration follows the list):

  1. multiply the image in the range [0.0, 1.0] by 2 ** bits to get [0.0, 2 ** bits]
  2. use any kind of rounding op, e.g. floor, ceil, or round, to get {0.0, 1.0, ..., 2 ** bits}
  3. clamp to remove the largest value to get {0.0, 1.0, ..., 2 ** bits - 1} (without the clamp we would have 2 ** bits + 1 values)
  4. divide by 2 ** bits to get back to the original range, but now with discrete values, i.e. {0.0, 1 / (2 ** bits), 2 / (2 ** bits), ..., 1 - 1 / (2 ** bits)}
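
A quick demonstration of these four steps for bits=3, using floor as the rounding op:

>>> bits = 3
>>> levels = 2 ** bits
>>> image = torch.linspace(0.0, 1.0, steps=1000)
>>> torch.unique(image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels))
tensor([0.0000, 0.1250, 0.2500, 0.3750, 0.5000, 0.6250, 0.7500, 0.8750])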

In total, this makes the implementation consistent across all dtypes, with the usual caveat of small differences between float and int:

import itertools

from torchvision.prototype.transforms import functional as F
import tqdm
import torch

for _ in tqdm.tqdm(range(10_000)):
    image = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)
    posterized_images = [
        F.posterize(F.convert_dtype_image_tensor(image, dtype), bits=3)
        for dtype in [torch.uint8, torch.int32, torch.int64, torch.float32, torch.float64]
    ]

    for image1, image2 in itertools.combinations(posterized_images, 2):
        if image1.is_floating_point() and image2.is_floating_point():
            torch.testing.assert_close(image1, image2, check_dtype=False)
        elif image1.is_floating_point() or image2.is_floating_point():
            try:
                if image1.is_floating_point():
                    image1 = F.convert_image_dtype(image1, image2.dtype)
                else:
                    image2 = F.convert_image_dtype(image2, image1.dtype)
            except RuntimeError:
                # Conversion is not safe. This has nothing to do with posterize
                continue
            torch.testing.assert_close(image1, image2, atol=1, rtol=0)
        else:
            dtype = torch.promote_types(image1.dtype, image2.dtype)
            image1 = F.convert_image_dtype(image1, dtype)
            image2 = F.convert_image_dtype(image2, dtype)
            torch.testing.assert_close(image1, image2)

Performance-wise, all integer dtypes are the same. Floating point inputs are slower, but still reasonable:

                                       |  main  |  posterize-dtypes
1 threads: --------------------------------------------------------
      (3, 512, 512), uint8, cpu        |   92   |          95      
      (3, 512, 512), uint8, cuda       |    6   |           5      
      (5, 3, 512, 512), uint8, cpu     |  450   |         451      
      (5, 3, 512, 512), uint8, cuda    |   34   |          34      
      (3, 512, 512), int64, cpu        |        |         112      
      (3, 512, 512), int64, cuda       |        |          55      
      (3, 512, 512), float32, cpu      |        |         214      
      (3, 512, 512), float32, cuda     |        |         108      
      (5, 3, 512, 512), int64, cpu     |        |        2000      
      (5, 3, 512, 512), int64, cuda    |        |         271      
      (5, 3, 512, 512), float32, cpu   |        |        1100      
      (5, 3, 512, 512), float32, cuda  |        |         536      

Times are in microseconds (us).

Depending on the shape, the proposed implementation is roughly 2-4x faster than the fallback of converting floats to uint8, performing the computation, and converting back:

def proposal(image):
    F.posterize_image_tensor(image, bits=3)


def fallback(image):
    F.convert_dtype_image_tensor(
        F.posterize_image_tensor(
            F.convert_image_dtype(image, torch.uint8),
            bits=3,
        ),
        image.dtype,
    )
                        |  proposal  |  fallback
1 threads: -------------------------------------
      (3, 512, 512)     |     216    |     589  
      (5, 3, 512, 512)  |    1300    |    5700  

Times are in microseconds (us).

cc @vfdev-5 @datumbox @bjuncek

def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
    if image.is_floating_point():
        levels = 1 << bits
        return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
@datumbox (Contributor) commented Oct 27, 2022

Can you explain the clamp at levels-1? This is the kind of implementation reference I had in mind. Which reference did you use? Also why are we multiplying by 2^bits instead of 2^bits-1 which is supposed to be the max for the specific type?

@pmeier (Collaborator Author) replied

I touched on this in step 3 of my top comment. Since the input range for float images is inclusive on the edges, i.e. [0.0, 1.0], image.mul(levels).floor_() gives us levels + 1 values, i.e. {0.0, 1.0, 2.0, ..., levels - 1.0, levels}.

However, we want the kernel to quantize to levels levels. Thus, we need to remove one level. For integer dtypes, the higher values are removed, i.e. the remaining values are {i * 2 ** (bit_depth - bits) for i in range(2 ** bits)}. For example

>>> bits = 3
>>> bit_depth = 8
>>> {i * 2 ** (bit_depth - bits) for i in range(2 ** bits)}
{0, 32, 64, 96, 128, 160, 192, 224}

As you can see, 255 / 256, which corresponds to 1.0 in floating point images, is missing. Thus, we also clamp that level away for floating point images.
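
Concretely, it is the clamp that folds the inclusive upper edge 1.0 onto the highest remaining level:

>>> bits = 3
>>> torch.tensor(1.0).mul(2 ** bits).floor_().clamp_(0, 2 ** bits - 1).div_(2 ** bits)
tensor(0.8750)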

@datumbox (Contributor) replied

I read your PR description but I still had questions on this. Why not do image.mul(levels - 1) to begin with? Multiplying by levels means that the upper bound of 1 will go outside of the permitted range of the type. What am I missing here?

@pmeier (Collaborator Author) commented Oct 27, 2022

To make this more concrete, let's look at an example:

>>> image = torch.arange(0, 256, dtype=torch.uint8)
>>> bits = 3
>>> output_baseline = image & (2 ** 8 - 2 ** (8 - bits))
>>> torch.unique(output_baseline)
tensor([  0,  32,  64,  96, 128, 160, 192, 224], dtype=torch.uint8)
>>> image = torch.linspace(0, 1, 100)
>>> output1 = image.mul(2 ** bits).floor_().clamp_(0, 2**bits - 1).div_(2 ** bits)
>>> torch.unique(output1.mul(255).byte())
tensor([  0,  31,  63,  95, 127, 159, 191, 223], dtype=torch.uint8)
>>> torch.unique(output1.mul(255))
tensor([  0.0000,  31.8750,  63.7500,  95.6250, 127.5000, 159.3750, 191.2500,
        223.1250])

The proposal in this PR is not perfect, but the .byte() call above eliminates some of the nuances.

In contrast, if we do what you propose, we get:

>>> output2 = image.mul(2 ** bits - 1).floor_().div(2 ** bits - 1)
>>> torch.unique(output2.mul(255).byte())
tensor([  0,  36,  72, 109, 145, 182, 218, 255], dtype=torch.uint8)

This is of course also a valid way to posterize an image to a bit depth of 3, but the behavior is divergent from what we and PIL do for integers.

@datumbox (Contributor) replied

Thanks for the explanation @pmeier. This makes sense. @vfdev-5 thoughts?

@vfdev-5 (Collaborator) commented Oct 28, 2022

@pmeier I think you can skip clamp with the following:

(x * (2 ** bits - 1)).floor() / (2 ** bits)

EDIT: for bits=1 the above method gives something unexpected:

x = torch.linspace(0.0, 1.0, steps=20)
bits = 1
(x * (2 ** bits - 1)).floor() / (2 ** bits)

# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.5000])
# vs
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
#         0.5000, 0.5000])

EDIT 2: a better quantization formula that skips the clamp:

(x * (2 ** bits - 0.5)).floor() / (2 ** bits)

bits=1
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
#         0.5000, 0.5000])

bits=3
# tensor([0.0000, 0.0000, 0.0000, 0.1250, 0.1250, 0.2500, 0.2500, 0.2500, 0.3750,
#         0.3750, 0.5000, 0.5000, 0.6250, 0.6250, 0.6250, 0.7500, 0.7500, 0.8750,
#         0.8750, 0.8750])

@pmeier (Collaborator Author) commented Oct 28, 2022

Unfortunately, this does not work. While the individual level values look good

>>> output3 = (image * (2 ** bits - 1)).floor() / (2 ** bits)
>>> torch.unique(output3.mul(255).byte())
tensor([  0,  31,  63,  95, 127, 159, 191, 223], dtype=torch.uint8)

the posterized values do not match what we do for the integers:

>>> _ = torch.manual_seed(0)
>>> bits = 3
>>> image_uint8 = torch.randint(0, 256, (3, 3), dtype=torch.uint8)
>>> posterized_image_uint8 = image_uint8 & (2 ** 8 - 2 **(8 - bits))
>>> posterized_image_uint8
tensor([[160,  32,  96],
        [192,  64, 224],
        [192,  96,   0]], dtype=torch.uint8)
>>> image_float32 = F.convert_dtype_image_tensor(image_uint8, torch.float32)
>>> posterized_image_float32 = (image_float32 * (2 ** bits - 1)).floor() / (2 ** bits)
>>> posterized_image_float32 = posterized_image_float32.mul(255).byte()
>>> posterized_image_float32
tensor([[127,  31,  95],
        [159,  31, 191],
        [159,  63,   0]], dtype=torch.uint8)
>>> posterized_image_uint8.int() - posterized_image_float32.int()
tensor([[33,  1,  1],
        [33, 33, 33],
        [33, 33,  0]], dtype=torch.int32)

This comes from the asymmetry between the multiplication and the division and is also what you observed for bits=1 above.
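
For instance, with bits=3 a pixel value of 0.99 ends up in different buckets under the two formulas:

>>> x = torch.tensor(0.99)
>>> (x * (2 ** 3 - 1)).floor() / (2 ** 3)  # multiply by 7, divide by 8
tensor(0.7500)
>>> x.mul(2 ** 3).floor_().clamp_(0, 2 ** 3 - 1).div_(2 ** 3)  # multiply and divide by 8
tensor(0.8750)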

@pmeier (Collaborator Author) replied

Re EDIT 2: the formula still produces different values:

>>> posterized_image_float32 = (image_float32 * (2 ** bits - 0.5)).floor() / (2 ** bits)
>>> posterized_image_float32.mul(255).byte()
tensor([[159,  31,  95],
        [159,  31, 223],
        [159,  95,   0]], dtype=torch.uint8)
>>> posterized_image_uint8.int() - posterized_image_float32.int()
tensor([[ 1,  1,  1],
        [33, 33,  1],
        [33,  1,  0]], dtype=torch.int32)

@vfdev-5 (Collaborator) replied

Well, using clamp you still have a difference of +/- 1, but OK, let's keep the clamp. It is probably not a big deal in terms of runtime performance.

@datumbox (Contributor) left a comment

LGTM, thanks @pmeier.

Before merging I would suggest the following:

  • Run a few additional checks on your side to confirm that images cast to floats and posterized with your algorithm yield the same results as their uint8 equivalents. Doing a few thousand iterations on random input should be more than enough to confirm it works.
  • Wait for @vfdev-5's input to see if there are any changes on his side.

Finally, what are your plans concerning equalize? Are you going to aim for a simple conversion or try to find an expansion of the op to other types? I think extending it to other integers can be more straightforward (not sure if it's going to be equally performant though). Floats will be much trickier.

@pmeier (Collaborator Author) commented Oct 27, 2022

> Run a few additional checks on your side to confirm that images cast to floats and posterized with your algorithm yield the same results as their uint8 equivalents. Doing a few thousand iterations on random input should be more than enough to confirm it works.

I've posted a correctness check in my top comment. Running it 10_000 times yielded no errors.

> Finally, what are your plans concerning equalize? Are you going to aim for a simple conversion or try to find an expansion of the op to other types? I think extending it to other integers can be more straightforward (not sure if it's going to be equally performant though). Floats will be much trickier.

I will bring a PR soon. I share your assessment here. Other integer dtypes should be simple. Floats are basically impossible with the current implementation. We moved away from general histogram ops, since they don't support batching. However, our current implementation relies on quantized inputs. I think the easiest will be to just convert into an integer dtype and keep the current implementation as is. The only thing that we need to decide is which dtype we want to convert to. Let's discuss this on the PR when I have benchmarks to assist a decision.

@vfdev-5 (Collaborator) left a comment

LGTM. An alternative formula to quantize the floating point range could be the following (it avoids the additional clamp call):

(x * (2 ** bits - 0.5)).floor() / (2 ** bits)

@pmeier (Collaborator Author) commented Oct 28, 2022

As discussed offline, I reverted back to 255, i.e. 8 bits, as the maximum value for all integer dtypes in 562028c. This will be reinstated in #6830 in case we decide to merge it.

@pmeier requested review from vfdev-5 and datumbox on October 28, 2022, 10:03
@datumbox (Contributor) left a comment

LGTM, thanks!

@vfdev-5 (Collaborator) left a comment

LGTM, thanks @pmeier

@pmeier merged commit 900982f into pytorch:main on Oct 28, 2022
@pmeier deleted the posterize-dtypes branch on October 28, 2022, 10:38
facebook-github-bot pushed a commit that referenced this pull request Oct 31, 2022
…es (#6847)

Summary:
* extend support of posterize to all integer and floating dtypes

* remove raise

* revert to fixed value range for integer dtypes

Reviewed By: datumbox

Differential Revision: D40851028

fbshipit-source-id: ebb0460ce9eb414515701303688b16a10dab0dee