extend support of posterize to all integer and floating dtypes #6847
Conversation
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
    if image.is_floating_point():
        levels = 1 << bits
        return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
Can you explain the clamp at levels - 1? This is the kind of implementation reference I had in mind. Which reference did you use? Also, why are we multiplying by 2 ** bits instead of 2 ** bits - 1, which is supposed to be the max for the specific type?
I touched on this in point 3 of my top comment. Since the input range for float images is inclusive on both edges, i.e. [0.0, 1.0], image.mul(levels).floor_() gives us levels + 1 values, i.e. {0.0, 1.0, 2.0, ..., levels - 1.0, levels}. However, we want the kernel to quantize to levels levels, so we need to remove one. For integer dtypes, the highest values are removed, i.e. the remaining values are {i * 2 ** (bit_depth - bits) for i in range(2 ** bits)}. For example:
>>> bits = 3
>>> bit_depth = 8
>>> {i * 2 ** (bit_depth - bits) for i in range(2 ** bits)}
{0, 32, 64, 96, 128, 160, 192, 224}
As you can see, 255 (which corresponds to 1.0 in floating point images) is missing from this set. Thus, we also clamp that top level away for floating images.
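To make the extra-level issue concrete, here is a small self-contained check (a sketch, using bits = 3 as an example):

```python
import torch

bits = 3
levels = 1 << bits  # 8 levels

# 257 evenly spaced points so that both edges 0.0 and 1.0 are included
image = torch.linspace(0.0, 1.0, 257)

# Without the clamp, floor(image * levels) produces levels + 1 distinct
# values, because the edge value 1.0 maps exactly to `levels`.
unclamped = image.mul(levels).floor_()

# The clamp folds the extra top value back into the highest remaining level.
clamped = unclamped.clamp(0, levels - 1)

print(torch.unique(unclamped).numel())  # 9 == levels + 1
print(torch.unique(clamped).numel())    # 8 == levels
```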
I read your PR description but I still had questions on this. Why not do image.mul(levels - 1) to begin with? Multiplying by levels means that the upper bound of 1.0 will go outside the permitted range of the type. What am I missing here?
To make this more concrete, let's look at an example:
>>> image = torch.arange(0, 256, dtype=torch.uint8)
>>> bits = 3
>>> output_baseline = image & (2 ** 8 - 2 ** (8 - bits))
>>> torch.unique(output_baseline)
tensor([ 0, 32, 64, 96, 128, 160, 192, 224], dtype=torch.uint8)
>>> image = torch.linspace(0, 1, 100)
>>> output1 = image.mul(2 ** bits).floor_().clamp_(0, 2**bits - 1).div_(2 ** bits)
>>> torch.unique(output1.mul(255).byte())
tensor([ 0, 31, 63, 95, 127, 159, 191, 223], dtype=torch.uint8)
>>> torch.unique(output1.mul(255))
tensor([ 0.0000, 31.8750, 63.7500, 95.6250, 127.5000, 159.3750, 191.2500,
223.1250])
The proposal in this PR is not perfect, but the .byte() call above glosses over some of the nuances.
In contrast, if we do what you propose, we get:
>>> output2 = image.mul(2 ** bits - 1).floor_().div(2 ** bits - 1)
>>> torch.unique(output2.mul(255).byte())
tensor([ 0, 36, 72, 109, 145, 182, 218, 255], dtype=torch.uint8)
This is of course also a valid way to posterize an image to a bit depth of 3, but the behavior diverges from what we and PIL do for integers.
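For reference, a quick script comparing the clamp-based float kernel against the uint8 bitmask baseline over all 256 values (a sketch: the uint8-to-float conversion is approximated here with a plain div(255), so the two paths agree only up to an off-by-one):

```python
import torch

bits = 3

# Integer baseline: keep the first `bits` bits via the bitmask.
image_uint8 = torch.arange(0, 256, dtype=torch.uint8)
baseline = image_uint8 & (2 ** 8 - 2 ** (8 - bits))

# Float path: convert, posterize with the clamp-based kernel, convert back.
image_float = image_uint8.float().div(255)
posterized = image_float.mul(2 ** bits).floor_().clamp_(0, 2 ** bits - 1).div_(2 ** bits)
roundtrip = posterized.mul(255).byte()

# The two paths differ by at most 1, caused by the 255 vs. 256 scaling.
max_diff = (baseline.int() - roundtrip.int()).abs().max().item()
print(max_diff)  # 1
```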
@pmeier I think you can skip the clamp with the following:

(x * (2 ** bits - 1)).floor() / (2 ** bits)

EDIT: for bits=1 the above method gives something unexpected:

x = torch.linspace(0.0, 1.0, steps=20)
bits = 1
(x * (2 ** bits - 1)).floor() / (2 ** bits)
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.5000])
# vs
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
#         0.5000, 0.5000])

EDIT2: a better quantization formula skipping the clamp:

(x * (2 ** bits - 0.5)).floor() / (2 ** bits)
# bits=1
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
#         0.5000, 0.5000])
# bits=3
# tensor([0.0000, 0.0000, 0.0000, 0.1250, 0.1250, 0.2500, 0.2500, 0.2500, 0.3750,
#         0.3750, 0.5000, 0.5000, 0.6250, 0.6250, 0.6250, 0.7500, 0.7500, 0.8750,
#         0.8750, 0.8750])
Unfortunately, this does not work. While the individual level values look good
>>> output3 = (image * (2 ** bits - 1)).floor() / (2 ** bits)
>>> torch.unique(output3.mul(255).byte())
tensor([ 0, 31, 63, 95, 127, 159, 191, 223], dtype=torch.uint8)
the posterized values do not match what we do for the integers:
>>> _ = torch.manual_seed(0)
>>> bits = 3
>>> image_uint8 = torch.randint(0, 256, (3, 3), dtype=torch.uint8)
>>> posterized_image_uint8 = image_uint8 & (2 ** 8 - 2 **(8 - bits))
>>> posterized_image_uint8
tensor([[160, 32, 96],
[192, 64, 224],
[192, 96, 0]], dtype=torch.uint8)
>>> image_float32 = F.convert_dtype_image_tensor(image_uint8, torch.float32)
>>> posterized_image_float32 = (image_float32 * (2 ** bits - 1)).floor() / (2 ** bits)
>>> posterized_image_float32 = posterized_image_float32.mul(255).byte()
>>> posterized_image_float32
tensor([[127, 31, 95],
[159, 31, 191],
[159, 63, 0]], dtype=torch.uint8)
>>> posterized_image_uint8.int() - posterized_image_float32.int()
tensor([[33, 1, 1],
[33, 33, 33],
[33, 33, 0]], dtype=torch.int32)
This comes from the asymmetry between the multiplication and the division, and is also what you observed for bits=1 above.
Re EDIT 2: formula still produces different values:
>>> posterized_image_float32 = (image_float32 * (2 ** bits - 0.5)).floor() / (2 ** bits)
>>> posterized_image_float32.mul(255).byte()
tensor([[159, 31, 95],
[159, 31, 223],
[159, 95, 0]], dtype=torch.uint8)
>>> posterized_image_uint8.int() - posterized_image_float32.int()
tensor([[ 1, 1, 1],
[33, 33, 1],
[33, 1, 0]], dtype=torch.int32)
Well, even with clamp you still have a difference of +/- 1, but OK, let's keep the clamp. It's probably not a big deal in terms of runtime performance.
LGTM, thanks @pmeier.
Before merging I would suggest:
- Running a few additional checks on your side to confirm that images cast to float and posterized with your algorithm yield the same results as their uint8 equivalents. A few thousand iterations on random input should be more than enough to confirm it works.
- Waiting for @vfdev-5's input to see if there are any changes on his side.

Finally, what are your plans concerning equalize? Are you going to aim for a simple conversion, or try to find an extension of the op to other types? I think extending it to the other integers should be straightforward (not sure if it will be equally performant, though). Floats will be much trickier.
I've posted a correctness check in my top comment and am running it now.

I will put up a PR soon. I share your assessment here: the other integer dtypes should be simple, but floats are basically impossible with the current implementation. We moved away from general histogram ops since they don't support batching; however, our current implementation relies on quantized inputs. I think the easiest option is to convert to an integer dtype and keep the current implementation as is. The only thing we need to decide is which dtype to convert to. Let's discuss this on the PR when I have benchmarks to assist the decision.
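Since equalize via dtype conversion came up, here is a hypothetical sketch of that fallback (the function name and the CDF normalization are my own; this is not torchvision's implementation):

```python
import torch

def equalize_float_via_uint8_sketch(image: torch.Tensor) -> torch.Tensor:
    """Hypothetical fallback: quantize a [0.0, 1.0] float image to uint8,
    histogram-equalize on the quantized values, and return floats in [0.0, 1.0]."""
    quantized = image.mul(255).round().clamp(0, 255).to(torch.uint8)
    # Histogram and cumulative distribution over the 256 quantized levels.
    hist = torch.bincount(quantized.flatten(), minlength=256)
    cdf = hist.cumsum(0).float()
    # Normalize the CDF so the output spans [0.0, 1.0].
    cdf = (cdf - cdf.min()) / (cdf.max() - cdf.min()).clamp(min=1)
    # Look up the equalized value for each pixel.
    return cdf[quantized.long()].reshape(image.shape)

torch.manual_seed(0)
image = torch.rand(16, 16)
out = equalize_float_via_uint8_sketch(image)
print(out.shape)  # torch.Size([16, 16])
```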
LGTM. An alternative formula to quantize the floating range could be the following (it avoids the additional clamp call):

(x * (2 ** bits - 0.5)).floor() / (2 ** bits)
LGTM, thanks!
LGTM, thanks @pmeier
extend support of posterize to all integer and floating dtypes (#6847)
Summary:
* extend support of posterize to all integer and floating dtypes
* remove raise
* revert to fixed value range for integer dtypes
Reviewed By: datumbox
Differential Revision: D40851028
fbshipit-source-id: ebb0460ce9eb414515701303688b16a10dab0dee
This PR addresses #6840 for posterize. I'm going to over-explain how the kernel works for uint8 as a base for the explanation of the other dtypes.

posterize for integers changes the bit depth of an image. For example, setting bits=3 changes to a bit depth of 3, i.e. 2 ** 3 == 8 different values. This happens by only keeping the first bits bits of each value. Since uint8 works modulo 2 ** 8, the mask in

vision/torchvision/transforms/functional_tensor.py, line 792 in add7596

is equivalent to 2 ** 8 - 2 ** (8 - bits), or in binary, a value whose first bits bits are set. By &'ing this mask with the image, we achieve the desired result.

For the other integer dtypes, we can simply use the same scheme. We already have the helper function in

vision/torchvision/prototype/transforms/functional/_meta.py, line 311 in add7596

which gives us the bit depths. Plus, since we treat the signed integers as unsigned ones by restricting their minimum to 0, we also don't need to care about the leading bit that encodes the sign.

posterize is thus a (re-)quantization transformation. Ignoring finite precision, floating point images are not quantized. Quantization can easily be achieved in four steps:

1. Multiply the input range [0.0, 1.0] by 2 ** bits to get [0.0, 2 ** bits].
2. Quantize, e.g. with floor (ceil or round would also work), to get {0.0, 1.0, ..., 2 ** bits}.
3. Clamp to {0.0, 1.0, ..., 2 ** bits - 1} (without this we would have 2 ** bits + 1 values).
4. Divide by 2 ** bits to get back to the original range, but now with discrete values, i.e. {0.0, 1 / (2 ** bits), 2 / (2 ** bits), ..., 1 - 1 / (2 ** bits)}.

In total this makes the implementation equal for all dtypes, with the usual caveat of small differences between float and int.

Performance-wise, all integer dtypes are the same. Floating point inputs are slower, but still reasonable: depending on the shape, the proposed implementation is roughly 2-4x faster than the fallback of converting floats to uint8, performing the computation, and converting back.

cc @vfdev-5 @datumbox @bjuncek