[PoC] reinstate get_params #7095

Closed
wants to merge 3 commits into from

pmeier
Collaborator

@pmeier pmeier commented Jan 16, 2023

Addresses #7092 (comment). The core idea is the following:

  • add two new static methods: get_params and _get_params_internal (name up for discussion)
  • implement the actual sampling in _get_params_internal
  • call it from both _get_params and get_params, and convert the result into the respective return format there
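
A minimal sketch of this layout, assuming a toy transform that works on plain height/width integers and uses the standard library instead of torch. `ToyRandomErasing` and its signatures are hypothetical and only illustrate how the three methods relate; they are not the code in this PR.

```python
import math
import random
from typing import Dict, Tuple


class ToyRandomErasing:
    """Hypothetical stand-in for a v2 transform; not the torchvision class."""

    def __init__(self, scale=(0.02, 0.33), ratio=(0.3, 3.3)):
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def _get_params_internal(img_h, img_w, scale, ratio) -> Tuple[int, int, int, int]:
        # Shared sampling logic: pick an erase box (top, left, height, width).
        area = img_h * img_w
        for _ in range(10):
            erase_area = area * random.uniform(scale[0], scale[1])
            aspect = math.exp(random.uniform(math.log(ratio[0]), math.log(ratio[1])))
            h = int(round(math.sqrt(erase_area * aspect)))
            w = int(round(math.sqrt(erase_area / aspect)))
            if 0 < h < img_h and 0 < w < img_w:
                i = random.randint(0, img_h - h)
                j = random.randint(0, img_w - w)
                return i, j, h, w
        # Fall back to the whole image if no valid box was found.
        return 0, 0, img_h, img_w

    def _get_params(self, img_h, img_w) -> Dict[str, int]:
        # v2-style instance method: wraps the shared result in a dict.
        i, j, h, w = self._get_params_internal(img_h, img_w, self.scale, self.ratio)
        return dict(i=i, j=j, h=h, w=w)

    @staticmethod
    def get_params(img_h, img_w, scale, ratio) -> Tuple[int, int, int, int]:
        # v1-style public static method: returns the plain tuple.
        return ToyRandomErasing._get_params_internal(img_h, img_w, scale, ratio)
```

Both entry points consume the same random draws, so with the same seed they return the same box, just in different containers.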

cc @vfdev-5 @datumbox @bjuncek

Collaborator Author

@pmeier pmeier left a comment

Let's discuss this design based on RandomErasing. I don't expect any large differences for the other transforms. I'll add them when we are happy with the design.

    ratio: Tuple[float, float],
    value: Optional[List[float]] = None,
) -> Tuple[int, int, int, int, torch.Tensor]:
    img_c, img_h, img_w = get_dimensions(image)
Collaborator Author

In v2, we use the query_chw idiom, since in general the transform should not be bound to images and we can extract this information from multiple types. Since we already have an image here, we can extract directly and pass that.

) -> Tuple[int, int, int, int, torch.Tensor]:
    img_c, img_h, img_w = get_dimensions(image)
    i, j, h, w, v = RandomErasing._get_params_internal(
        img_c, img_h, img_w, scale, torch.log(torch.tensor(ratio)), value
Collaborator Author

Instead of passing ratio directly, v2 stores the log ratio as an instance attribute to avoid recomputing it on every call.
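
A small sketch of that caching idea, using the standard library rather than torch. `LogRatioSampler` and `_log_ratio` are hypothetical names chosen for illustration only.

```python
import math
import random


class LogRatioSampler:
    """Hypothetical helper; illustrates caching the log of the ratio bounds."""

    def __init__(self, ratio=(0.3, 3.3)):
        self.ratio = ratio
        # Computed once here instead of on every sampling call.
        self._log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

    def sample_aspect_ratio(self) -> float:
        # Draw uniformly in log space, then map back with exp.
        lo, hi = self._log_ratio
        return math.exp(random.uniform(lo, hi))
```

The static v1-style get_params has no instance to cache on, which is why the snippet above has to compute the log at the call site.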

Comment on lines +126 to +127
if v is None:
    v = image
Collaborator Author

@pmeier pmeier Jan 16, 2023

In v2, the whole transformation is a no-op when value is None, for performance. In v1, we instead returned the input image as v and performed an "erasing" by replacing the image with itself.

Collaborator Author

See #7092 (comment) for a detailed explanation.

Member

@NicolasHug NicolasHug left a comment

Thanks Philip, I took a brief look at it

Instead of having a 3rd _get_params_internal() function and 2 shallow wrappers, could we eliminate one of them?

i.e. just let _get_params() do 99% of the work and massage what needs to be in the public get_params()? This way when/if we remove get_params() eventually, we won't have to change the other methods.

@pmeier
Collaborator Author

pmeier commented Jan 18, 2023

i.e. just let _get_params() do 99% of the work and massage what needs to be in the public get_params()? This way when/if we remove get_params() eventually, we won't have to change the other methods.

get_params cannot call _get_params, since the former is a static method while the latter is an instance method. The only way to eliminate one of them is to let _get_params call get_params, but for v2 we made a few performance and structure improvements that will be hard to backport. For example, without a third method we can't have extra fields in the return value, which we commonly use in v2 to avoid performing an operation in the first place.

Of course we could also create a dummy object that represents self if we want to call _get_params from get_params. IIRC, we only ever access values in there and so a simple namespace would suffice.

As a final note, I don't know if we supported or enforced JIT scriptability on get_params. We need to check that as well.

@pmeier
Collaborator Author

pmeier commented Jan 20, 2023

Of course we could also create a dummy object that represents self if we want to call _get_params from get_params. IIRC, we only ever access values in there and so a simple namespace would suffice.

As a final note, I don't know if we supported or enforced JIT scriptability on get_params. We need to check that as well.

I did a little digging. get_params is an old concept that dates back almost to the beginning of the transforms API. It was conceptualized in #230 (comment), albeit as an instance method. Over the course of the PR that addressed the issue, the design changed to the current static method in #240 (comment). Interestingly, in the beginning get_params was used for all kinds of parameter sampling, not just random sampling. For example, it was also used for CenterCrop:

@staticmethod
def get_params(img, output_size):
    w, h = img.size
    th, tw = output_size
    x1 = int(round((w - tw) / 2.))
    y1 = int(round((h - th) / 2.))
    return x1, y1, tw, th

We have similar behavior in v2, although not for CenterCrop since the computation moved into the kernels eventually.

At this point, JIT was non-existent and thus also not part of the design. That changed with #2292, which mostly dealt with unifying the PIL and tensor backend, but also required JIT scriptability. The original plan only comprised deterministic transforms, which makes sense. However, later on the list was extended to also include random transforms. JIT scriptability is useful for inference, but I can't think of a use case of using random transforms for that. Unfortunately, it is unclear from the context whether this was done to achieve full scriptability of the transforms API or if it was backed by an actual use case. cc @vfdev-5 @fmassa

In any case, JIT scriptability of get_params is enforced through our tests. Since the forward of the random transforms calls into get_params, scripting the transform covers it as well:

def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    meth_kwargs = meth_kwargs or {}
    # test for class interface
    f = transform_cls(**meth_kwargs)
    scripted_fn = torch.jit.script(f)

_test_class_op(
T.ColorJitter,

(This is just one example. You can find JIT tests for all other transforms in the same test module.)

In conclusion, JIT scriptability of get_params was intentional and is currently enforced for v1. I can't think of a use case where this is useful for random transforms, but maybe I'm missing something.

I think we should tie the decision here to our decision for JIT scriptability in general: if we decide to support it (temporarily or permanently) after all, the same should apply to the temporary get_params for BC. Meaning, if we need to support JIT scriptability, we need to go with what I proposed above. If we don't, we can maybe get away with a fake self namespace.

@pmeier
Collaborator Author

pmeier commented Jan 20, 2023

In 21946f7 I've added a design that uses a fake self namespace. It just wraps the regular _get_params and thus will be very easy to remove after the deprecation period.
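
A rough sketch of what a fake self namespace can look like, assuming _get_params only reads attributes from self. `ToyRandomCrop` and its parameters are hypothetical and are not the code from the commit above.

```python
import random
from types import SimpleNamespace


class ToyRandomCrop:
    """Hypothetical v2-style transform; not the torchvision implementation."""

    def __init__(self, size):
        self.size = size

    def _get_params(self, img_h, img_w):
        # Instance method that only reads attributes from self.
        th, tw = self.size
        top = random.randint(0, img_h - th)
        left = random.randint(0, img_w - tw)
        return dict(top=top, left=left, height=th, width=tw)

    @staticmethod
    def get_params(img_h, img_w, output_size):
        # Build a stand-in for self carrying just the attributes _get_params reads.
        fake_self = SimpleNamespace(size=output_size)
        params = ToyRandomCrop._get_params(fake_self, img_h, img_w)
        # Convert the v2-style dict back into a v1-style tuple.
        return params["top"], params["left"], params["height"], params["width"]
```

Because the wrapper adds no sampling logic of its own, removing it later only means deleting the static method. As noted earlier in the thread, this option is tied to not needing JIT scriptability for get_params.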

@NicolasHug
Member

Thanks a lot for the deep dive Philip

I think we should tie the decision here to our decision for JIT scriptability in general: if we decide to support it (temporarily or permanently) after all, the same should apply to the temporary get_params for BC. Meaning, if we need to support JIT scriptability, we need to go with what I proposed above. If we don't, we can maybe get away with a fake self namespace.

I agree, let's postpone this decision until we're clearer about the JIT requirements

@pmeier
Collaborator Author

pmeier commented Jan 25, 2023

One option that has slipped my mind so far: if we only want to support get_params until transforms v2 is stable, and then deprecate and remove it together with v1, we could simply alias the v1 function:

class MyTransformV2(transforms.Transform):
    ...

    @staticmethod
    def get_params(*args, **kwargs):
        # Forward all arguments unchanged to the v1 implementation.
        return MyTransformV1.get_params(*args, **kwargs)

@vadimkantorov

A naming nitpick: would sample_params be a better name? It would reflect that in most cases the method draws random numbers (and thus it might be a good idea to also allow an optional generator= argument as an idiom).
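
To illustrate the generator= idiom, here is a minimal sketch that uses Python's `random.Random` as a stand-in for a torch generator. The function `sample_params` and its signature are hypothetical.

```python
import random
from typing import Optional, Tuple


def sample_params(
    img_h: int,
    img_w: int,
    size: Tuple[int, int],
    generator: Optional[random.Random] = None,
) -> Tuple[int, int]:
    # Fall back to the module-level RNG when no generator is passed.
    rng = generator if generator is not None else random
    top = rng.randint(0, img_h - size[0])
    left = rng.randint(0, img_w - size[1])
    return top, left
```

Passing two generators created with the same seed yields the same parameters, without touching global random state.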

@NicolasHug
Member

A naming nitpick: should this be better called sample_params?

Yes, probably, but we're only introducing those for BC - the goal is to be backward-compatible, so we have to keep the same name.

@pmeier
Collaborator Author

pmeier commented Jan 31, 2023

Superseded by #7153.

@pmeier pmeier closed this Jan 31, 2023
@pmeier pmeier deleted the get_params branch February 2, 2023 07:45
4 participants