
Add convert_image_dtype to functionals #2078

Merged · 20 commits · Jun 11, 2020

Conversation

@pmeier (Collaborator) commented Apr 8, 2020

This adds a convert_image_dtype function as discussed in #2060 (comment).

The idea behind this function is to first convert the image into the interval [0.0, 1.0] and afterwards into the value interval of the desired dtype.
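
For illustration, a minimal sketch of that two-step scheme (scale_factor is a hypothetical helper, matching the name used in a snippet further down; all corner cases discussed below are ignored):

import torch

def scale_factor(dtype: torch.dtype) -> float:
    # assumption: floats live in [0.0, 1.0], ints in [0, iinfo(dtype).max]
    return 1.0 if dtype.is_floating_point else float(torch.iinfo(dtype).max)

def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    image = image / scale_factor(image.dtype)       # step 1: into [0.0, 1.0]
    return (image * scale_factor(dtype)).to(dtype)  # step 2: into the target range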

@fmassa (Member) left a review comment:

Thanks for the quick PR!

I have a few comments. Also, I would propose following the TensorFlow implementation a bit more closely, as there are a few cases that need to be taken into account.

Let me know what you think

(Inline review comments on torchvision/transforms/functional.py and test/test_transforms.py; outdated, resolved.)
@pmeier (Collaborator, Author) commented Apr 8, 2020

@fmassa If I understand saturate correctly, it is as simple as this:

image = image / scale_factor(image.dtype)
if saturate:
    image = torch.clamp(image, 0.0, 1.0)
image = image * scale_factor(dtype)

The problem that saturate is solving only arises for floating point images. All other image types are by their nature always limited to their max value. By clamping to [0.0, 1.0] after the first scaling, we can ensure that afterwards no overflow errors can arise.

Am I missing something here, or is it just that simple?

Edit:

I think I understand the problem. In theory my approach works, if it weren't for the limited precision of floating point tensors:

import torch


def convert(x, dtype):
    return x.mul(torch.iinfo(dtype).max).to(dtype)


x = torch.tensor(1.0, dtype=torch.float)
for dtype in (torch.short, torch.int, torch.long):
    print(convert(x, dtype))
which prints:

tensor(32767, dtype=torch.int16)
tensor(-2147483648, dtype=torch.int32)
tensor(-9223372036854775808)

I will handle this, but it will probably take some time.

@pmeier (Collaborator, Author) commented Apr 10, 2020

I think this will take some more work and some decisions. I dug into the TF implementation and they are basically splitting this into 4 cases:

  1. float to float
  2. float to int
  3. int to float
  4. int to int

I'll go through them one by one.

@pmeier (Collaborator, Author) commented Apr 10, 2020

  1. float to float

This is the simplest one, as it is basically just a cast since the intervals are the same. One caveat though: even with saturate=True they do not perform any saturation. I don't see a reason why we shouldn't include it in this case: if we explicitly set saturate=True, the function should honor it. Thoughts?
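
For illustration, honoring the flag here could look as simple as this (a sketch, not what TF does):

import torch

def float_to_float(image: torch.Tensor, dtype: torch.dtype, saturate: bool = False) -> torch.Tensor:
    # just a cast; optionally clamp to the expected [0.0, 1.0] interval
    image = image.to(dtype)
    return image.clamp(0.0, 1.0) if saturate else image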

@pmeier (Collaborator, Author) commented Apr 10, 2020

  2. float to int

In TF they basically perform floor(image * (c + 0.5)), where c is the maximum value of the dtype, i.e. torch.iinfo(dtype).max. Let's assume we want to cast to a fictitious 2-bit dtype uint2. This maps the values as follows:

{
    [  0, 2/7): 0,
    [2/7, 4/7): 1,
    [4/7, 6/7): 2,
    [6/7,   1]: 3,
}

As you can see, the last interval is significantly smaller than the others. In general, the last interval is given by [c / (c + 1/2), 1]; thus, for higher values of c it will be even smaller.
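
The skew is easy to check empirically for the hypothetical 2-bit dtype (c = 3):

import torch

a = torch.linspace(0, 1, 7001)
print(a.mul(3.5).floor().long().bincount())
# the last bin receives only about half as many values as the others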

IMO we should aim for something like this:

{
    [  0, 1/4): 0,
    [1/4, 1/2): 1,
    [1/2, 3/4): 2,
    [3/4,   1]: 3,
}

We could achieve this with floor(min(image * (c + 1), c)). I did some timing analysis:

import timeit
import torch


x = torch.ones((1, 3, 256, 256))
dtype = torch.uint8
c = float(torch.iinfo(dtype).max)

def theirs(x):
    return x.mul(c + 0.5).to(dtype)

def ours(x):
    # clamp from above, so that values at exactly 1.0 cannot exceed c
    return x.mul(c + 1.0).clamp(max=c).to(dtype)

number = 10000

their_time = timeit.timeit(lambda: theirs(x), number=number)
print(f"their time: {their_time /number * 1e6:.2f} µs")

our_time = timeit.timeit(lambda: ours(x), number=number)
print(f"our time: {our_time / number * 1e6:.2f} µs")

rel_diff = our_time / their_time - 1.0
print(f"rel. diff.: {rel_diff:+.1%}")
which prints:

their time: 166.69 µs
our time: 200.31 µs
rel. diff.: +20.2%

Mileage may vary for different systems or runs. While this is a significant relative increase, I think the absolute difference of roughly 35 µs per call is probably acceptable. Thoughts?

@pmeier (Collaborator, Author) commented Apr 10, 2020

  3. int to float

They cast to dtype first and subsequently divide by torch.iinfo(image.dtype).max. I would do it the same way.
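
A minimal sketch of that order of operations:

import torch

def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # cast first, then scale into [0.0, 1.0] by the max of the *input* dtype
    return image.to(dtype) / torch.iinfo(image.dtype).max

x = torch.tensor([0, 127, 255], dtype=torch.uint8)
print(int_to_float(x, torch.float32))  # tensor([0.0000, 0.4980, 1.0000])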

@fmassa (Member) commented Apr 14, 2020

Hi Philip,

About your points:

1 - float to float and saturate

I'm not sure we should pay a (fairly large) runtime penalty for the saturation check. We should pretty much never encounter any value larger than 3.4028234663852886e+38 for an image (and if we do encounter one, that is probably an error on the user's side).

2 - float to int and unbalanced last element

This is a fair point, and it seems like the TF implementation is suboptimal.
I would propose an alternative solution though, which avoids the clamp (and thus makes things a bit faster).
Instead of doing

floor(min(image * (c + 1), c))

why not do instead

floor(image * (c + 1 - eps))

where eps is, say, 0.001?
Here is a quick test case:

a = torch.linspace(0, 1, 10001)
print(a.mul(127.999).floor().int().bincount())

yields

tensor([79, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78,
        78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78,
        78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78,
        78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79,
        78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78,
        78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78,
        78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78,
        78, 79])

while

print(a.mul(3.999).floor().int().bincount())

gives

tensor([2501, 2501, 2500, 2499])

@fmassa (Member) commented Apr 14, 2020

Also, as a general note, I think it would be better to either move the functions that are defined inside the main function outside of it, or inline their code into the main function. They are very short anyway and are only called once, so there is no point in having them as separate functions (plus we pay the overhead of re-defining them at every call, and it makes things harder for torchscript as well).

My preference would be to inline the helper functions into the main code.

@pmeier (Collaborator, Author) commented Apr 20, 2020

We should pretty much never encounter any value larger than 3.4028234663852886e+38 for an image (and if we do encounter one, that is probably an error on the user's side).

Maybe I still got the saturate flag wrong: should it enforce the correct value range after passing a tensor through convert_image_dtype, or should it only prevent overflow? I was under the impression that the former is the case, but you seem to imply you only want the latter.

why not do instead floor(image * (c + 1 - eps))

Fair point. While experimenting with it I've encountered another problem (the same goes for my approach): float32 can only represent consecutive integers exactly up to 2 ** 24:

Integers between 2 ** n and 2 ** (n+1) round to a multiple of 2 ** (n-23) (notation mine)
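
The rule is easy to verify:

import torch

# 2 ** 24 + 1 is the first positive integer float32 cannot represent exactly
print(torch.tensor(2 ** 24 + 1, dtype=torch.float32))  # tensor(16777216.)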

If we, for example, want to convert a float32 image to int32, we have n == 30, since the highest int32 value is c = 2 ** 31 - 1. That means integers in the upper intervals will be rounded to the next multiple of 2 ** 7 = 128. Consider the following:

import torch

c = float(torch.iinfo(torch.int32).max)
eps = 1e-3

image = torch.tensor(1.0, dtype=torch.float)

scaled_images = (
    image * (c + 1 - eps),
    image * (c + 0.5),
    image * (c + 1) - 64,
    image * (c + 1) - 65,
)

print("\n".join([str(image.to(torch.int32)) for image in scaled_images]))
which prints:

tensor(-2147483648, dtype=torch.int32)
tensor(-2147483648, dtype=torch.int32)
tensor(-2147483648, dtype=torch.int32)
tensor(2147483520, dtype=torch.int32)

For our example we have to subtract at least 2 ** 6 + 1 = 65 to avoid overflow. This number of course changes for other conversions. Any ideas on how we want to handle this (efficiently)?


My preference would be to inline the helper functions into the main code

Agreed. I'll keep them separate until the last commit to help myself keep a better overview.

@fmassa (Member) commented Apr 21, 2020

It all depends on what we mean by saturate. In TF, they only handle overflow / underflow:

Note that converting from floating point inputs to integer types may lead to
over/underflow problems. Set saturate to True to avoid such problem in
problematic conversions. If enabled, saturation will clip the output into the
allowed range before performing a potentially dangerous cast (and only before
performing such a cast, i.e., when casting from a floating point to an integer
type, and when casting from a signed to an unsigned type; saturate has no
effect on casts between floats, or on casts that increase the type's range).

I think this definition makes sense, and I'm not sure we would want to clamp float values to be within [0, 1] inside this function.

Those are good points, and that's probably why the TF implementation has so many conditionals -- to make the implementation fast when possible.
Here is what they do in this case, which involves a few more conversions. I think we might need to do something similar (and thus have a few different branches in our implementation).
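
Roughly, the branching could be shaped like this (a sketch of the four cases, not TF's actual code):

import torch

def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    if image.dtype == dtype:
        return image
    if image.dtype.is_floating_point:
        if dtype.is_floating_point:
            return image.to(dtype)  # 1. float -> float: a plain cast
        raise NotImplementedError("2. float -> int: scale, maybe saturate, cast")
    if dtype.is_floating_point:
        return image.to(dtype) / torch.iinfo(image.dtype).max  # 3. int -> float
    raise NotImplementedError("4. int -> int: rescale in integer arithmetic")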

@pmeier (Collaborator, Author) commented Apr 23, 2020

I don't know how or if that works for them. I've ported the saturate_cast logic to torch:

import torch


def saturate_cast(value: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    def info(dtype):
        if dtype.is_floating_point:
            return torch.finfo(dtype)
        else:
            return torch.iinfo(dtype)

    input_info = info(value.dtype)
    output_info = info(dtype)

    if input_info.min < output_info.min:
        value = torch.max(value, torch.tensor(output_info.min, dtype=value.dtype))
    if input_info.max > output_info.max:
        value = torch.min(value, torch.tensor(output_info.max, dtype=value.dtype))
    return value.to(dtype)


image = torch.tensor(1.0, dtype=torch.float32)
dtype = torch.int32

scale = torch.iinfo(dtype).max + 0.5
scaled = image * scale

print(scaled)
print(saturate_cast(scaled, dtype))
which prints:

tensor(2.1475e+09)
tensor(-2147483648, dtype=torch.int32)

I expected this much, since it does not handle the problem I've described above. I don't have the capability to set up TF. Could you (or someone else) try this in TF and see if they simply missed this, or why it works for them?

@fmassa (Member) commented Apr 27, 2020

Good point!

I just tried the above snippet with TF (using colab), and got the same results as in PyTorch

import tensorflow as tf
a = tf.fill([1], 2147483647.5)  # yields a float32 tensor
print(tf.dtypes.saturate_cast(a, dtype=tf.int32))

which gives

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([-2147483648], dtype=int32)>

I'm not sure what the best approach to follow here is.
If we can find a sufficiently efficient (and simple) implementation that handles those cases, then that would be great.
But if it's not possible, I'd say this is something we should live with and properly document. I wouldn't expect float tensors containing images to span such large ranges of values, so I think this will in general not be an issue.

@pmeier (Collaborator, Author) commented Apr 27, 2020

I'll work something out and get back to you.

@pmeier (Collaborator, Author) commented Apr 28, 2020

I've played with it, and I don't think this can be handled in an easy or concise way. With a little effort I can safeguard the upper limit, but then the lower limit is no longer 0 after the cast.

I wouldn't expect float tensors containing images to span such large ranges of values, so I think this will in general not be an issue.

Either I'm missing your point, or this assumption is incorrect. The problem applies to every conversion of a floating point tensor to an int tensor with the same or a higher number of bits. So without further handling, the conversion from float32 to int32 is not safe, since 1.0 is a perfectly valid value for the input image.

I'm not sure how to move forward on this.


Edit:

I've found a way to handle both the upper and lower bounds. Let me know what you think:

import torch
import itertools

float_dtypes = (torch.float32, torch.float64)
int_dtypes = (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)

int_nextpow2 = {
    torch.float32: 23,
    torch.float64: 52,
    torch.uint8: 8,
    torch.int8: 7,
    torch.int16: 15,
    torch.int32: 31,
    torch.int64: 63,
}


def float_to_int(x, dtype):
    max = torch.iinfo(dtype).max

    m = int_nextpow2[x.dtype]
    n = int_nextpow2[dtype]

    if m >= n:
        # the float dtype can represent max exactly: scale and cast directly
        return (x * max).to(dtype)
    else:
        # subtract c before the cast to avoid overflow from rounding,
        # add it back afterwards, and clamp the lower bound at 0
        c = 2 ** (n - (m + 1))
        return torch.max((x * max - c).to(dtype) + c - 1, torch.zeros(1, dtype=dtype))


for float_dtype, int_dtype in itertools.product(float_dtypes, int_dtypes):
    x = torch.tensor((0.0, 1.0), dtype=float_dtype)
    y = float_to_int(x, int_dtype)

    actual = tuple(y.tolist())
    desired = (0, torch.iinfo(int_dtype).max)
    if actual != desired:
        print(
            (
                f"Conversion from {float_dtype} to {int_dtype} did not work as "
                f"expected: {actual} != {desired}"
            )
        )

The int_nextpow2 table represents the last power of two up to which two consecutive integers can still be differentiated. For integer dtypes that is simply log2(max(int_dtype) + 1). For floating point dtypes it represents the number of fraction bits. (I'm not happy with the name int_nextpow2; if you can think of something better, feel free to share.)

The idea is to check if the float dtype can handle max values of the int dtype before the cast. If that is not the case we subtract a constant c to avoid overflow resulting from rounding errors. After the cast we simply add c again. On the lower bound a "clamp" is sufficient, since values around 0 can always be represented accurately with floating point dtypes.
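
Worked through for float32 -> int32 (m = 23, n = 31, so c = 2 ** 7 = 128):

import torch

x = torch.tensor(1.0, dtype=torch.float32)
max_val = torch.iinfo(torch.int32).max       # 2 ** 31 - 1
scaled = x * max_val                         # rounds up to 2.0 ** 31 in float32
safe = (scaled - 128).to(torch.int32) + 127  # 2 ** 31 - 128 is exactly representable
print(safe)  # tensor(2147483647, dtype=torch.int32)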

@fmassa (Member) commented May 4, 2020

Hi @pmeier

The issue with the last solution you proposed is that we get back to the original behavior we were trying to fix, which is that the 255 value (for uint8) now has a different distribution than the others (it only appears if exactly 1.0 is passed):

a = torch.linspace(0, 1, 10001)
r = float_to_int(a, torch.uint8).bincount()
print(r)

gives us

tensor([40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39,
        40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39,
        39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39,
        39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39,
        39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39,
        39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39,
        39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40,
        39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40,
        39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39,
        40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39,
        40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39,
        39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39,
        39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39,
        39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39,
        39, 39, 39,  1])

Maybe there is an easy fix for this though (like passing 1 - eps instead of 1 somewhere).

Proposal to move forward

In order to move forward, I would propose that we only allow float -> integer conversion if the dtype allows for the correct behavior, and raise an error if this is not the case. So we would only allow converting float32 to {uint8, int8, int16}, and float64 would additionally allow converting to int32. This way we can keep the behavior correct, while postponing the decision on what to do in those corner cases to the future.
I believe we will rarely want to convert float images to dtypes larger than int16 (as not many image formats support them), so this would be fine for most cases, and if we start getting feature requests in the future we can reconsider.
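
For illustration, such a guard could be as simple as this (a sketch; the names and the bit table are mine, not from the PR):

import torch

# a float dtype can safely reach an int dtype only if its significand has at
# least as many bits as the int dtype has value bits
SIGNIFICAND_BITS = {torch.float32: 24, torch.float64: 53}
VALUE_BITS = {torch.uint8: 8, torch.int8: 7, torch.int16: 15,
              torch.int32: 31, torch.int64: 63}

def check_float_to_int(float_dtype: torch.dtype, int_dtype: torch.dtype) -> None:
    if SIGNIFICAND_BITS[float_dtype] < VALUE_BITS[int_dtype]:
        raise RuntimeError(
            f"The cast from {float_dtype} to {int_dtype} cannot be performed safely."
        )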

Thoughts?

@pmeier (Collaborator, Author) commented May 4, 2020

The issue with the last solution you proposed is that we get back to the original behavior we were trying to fix

Good catch! Seems I was so focused on fixing this that I forgot about that.

In order to move forward, I would propose that we only allow float -> integer conversion if the dtype allows for the correct behavior, and raise an error if this is not the case.

I think that is reasonable. Do you want me to completely disable this or add a force flag?

@fmassa (Member) commented May 4, 2020

Do you want me to completely disable this or add a force flag?

I would say to completely disable this for now, and raise an error (with a good error message) if the user tries to do this. We can then see how many users complain about it in the future.

Also, one thing I noticed in the TF convert_image_dtype implementation is that they mention that the interval for float dtypes is expected to be [0, 1), i.e. a half-open interval that does not include 1. This would "solve" some of the issues you are facing, although I'm not sure it is something we should adopt for now, as clamping to 1 is a very common thing to do.

@fmassa (Member) commented Jun 8, 2020

@pmeier do you think you would have some time to work on this sometime this week? Otherwise I can build on top of it and get it merged.

@pmeier (Collaborator, Author) commented Jun 8, 2020

@fmassa Sorry for the hold-up. I'm booked until Friday. If you need this before then, feel free to build on top of it. Otherwise I'll work on it on Friday and should get it done, provided I don't stumble upon another issue that needs discussing.

@fmassa (Member) commented Jun 8, 2020

Sounds good, thanks for the heads up! This can wait until Friday, thanks a lot!

@pmeier (Collaborator, Author) commented Jun 9, 2020

@fmassa Maybe we can discuss this before I work on it further: the last missing conversion is int to int. Without going through floating point, we can multiply or divide all values by (2 ** m) / (2 ** n), where m and n are the number of bits of the new and the original dtype, respectively.

The conversion of a black pixel, i.e. 0, is unproblematic. The conversion of a white pixel, i.e. 2 ** n - 1, ultimately boils down to this:

2 ** m - 1 + floor(1 - 2 ** (m - n))

The first part (2 ** m - 1) is just what we want, but the "error term" might be problematic:

  • n > m: floor(1 - 2 ** (m - n)) == 0
  • n < m: floor(1 - 2 ** (m - n)) == - (2 ** (m - n) - 1) <= -1

Thus, if we convert from a higher number of bits to a lower one (n > m), everything is fine, but if the conversion goes the other way (n < m), the maximum values are not mapped onto each other.
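
To make this concrete (dtypes chosen for brevity, uint8 with n = 8 and int16 with m = 15):

import torch

down = torch.tensor([0, 32767], dtype=torch.int16) // 2 ** (15 - 8)
print(down.to(torch.uint8))  # tensor([  0, 255]): max maps to max

up = torch.tensor([0, 255], dtype=torch.uint8).to(torch.int16) * 2 ** (15 - 8)
print(up)                    # tensor([    0, 32640]): 32640 < 32767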

Is this something you want to address further, or should we simply leave it as is? In TF they have the same problem, but it is not documented.

@fmassa (Member) commented Jun 9, 2020

Thus, if we convert from a higher number of bits to a lower one (n > m), everything is fine, but if the conversion goes the other way (n < m), the maximum values are not mapped onto each other.

I think it's fine if we don't map 255 exactly to 2147483647 (or to 32767 for int16), as this is the simplest thing we can do without incurring too much added complexity, and it's an OK trade-off in my opinion.

@pmeier pmeier force-pushed the convert_image_dtype branch from 613f0cd to 28e2fbf on June 11, 2020 13:16
@fmassa (Member) commented Jun 11, 2020

nit: adjust_hue would be a first candidate for using convert_image_dtype (#2300)

but not in this PR, just something to keep in mind

@pmeier pmeier requested a review from fmassa on June 11, 2020 13:51
@fmassa (Member) left a review:

Looks great, thanks a lot @pmeier !

As a follow-up PR, could you add tests for torchscript support as well?

@fmassa fmassa merged commit c2e8a00 into pytorch:master on Jun 11, 2020
@pmeier (Collaborator, Author) commented Jun 11, 2020

As a follow-up PR, could you add tests for torchscript support as well?

Could you point me to an example of how to do that?

@fmassa (Member) commented Jun 11, 2020

@pmeier it will basically be another line in the test that checks that fn = torch.jit.script(F.convert_image_dtype) works and gives the same results as F.convert_image_dtype; see

script_vflip = torch.jit.script(F_t.vflip)
img_tensor = torch.randn(3, 16, 16)
img_tensor_clone = img_tensor.clone()
vflipped_img = F_t.vflip(img_tensor)
vflipped_img_again = F_t.vflip(vflipped_img)
for an example
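
Adapted for this function, the check could look like this (a sketch, assuming the signature merged in this PR):

import torch
import torchvision.transforms.functional as F

script_fn = torch.jit.script(F.convert_image_dtype)
img = torch.rand(3, 16, 16)  # float image in [0, 1]
assert torch.equal(F.convert_image_dtype(img, torch.uint8),
                   script_fn(img, torch.uint8))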
