Optim-wip: Add new StackImage parameterization & JIT support for SharedImage #833

Merged: 10 commits merged into pytorch:optim-wip on May 17, 2022

Conversation

@ProGamerGov (Contributor) commented Dec 31, 2021

  • Added SimpleTensorParameterization as a workaround for JIT not supporting nn.ParameterList. It also helps StackImage support tensor inputs. Edit: Later versions of PyTorch now support nn.ParameterList, but the current solution supports all torch versions supported by Captum. (A rough sketch of this workaround follows this list.)
  • Added JIT support for SharedImage. I had to separate the interpolation operations due to a bug with JIT's union support: JIT RuntimeError: 'Union[Tensor, List[float], List[int]]' object is not subscriptable (pytorch#69434).
  • Added a new parameterization called StackImage, which stacks multiple parameterizations (that can be on different devices) along the batch dimension. Lucid uses this parameterization a few times, and I thought that I could improve on it with the multi-device functionality.
  • Added tests for InputParameterization and ImageParameterization
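
For context, here is a rough sketch of the workaround idea behind SimpleTensorParameterization (the TensorWrapper and WrappedTensors names below are hypothetical, and this is not the PR's actual code): wrap each tensor in a tiny scriptable nn.Module so that an nn.ModuleList of wrappers can stand in for nn.ParameterList under torch.jit.script.

# Hypothetical sketch only; Captum's SimpleTensorParameterization may differ.
import torch
import torch.nn as nn

class TensorWrapper(nn.Module):
    def __init__(self, tensor: torch.Tensor) -> None:
        super().__init__()
        # Trainable tensors become Parameters; others are stored as buffers.
        if tensor.requires_grad:
            self.tensor = nn.Parameter(tensor)
        else:
            self.register_buffer("tensor", tensor)

    def forward(self) -> torch.Tensor:
        return self.tensor

class WrappedTensors(nn.Module):
    def __init__(self, tensors) -> None:
        super().__init__()
        # nn.ModuleList is scriptable even on older PyTorch versions,
        # unlike nn.ParameterList, which is why the wrapper is useful.
        self.parts = nn.ModuleList([TensorWrapper(t) for t in tensors])

    def forward(self) -> torch.Tensor:
        out = []
        for p in self.parts:
            out.append(p())
        return torch.cat(out, dim=0)

scripted = torch.jit.script(
    WrappedTensors([torch.randn(1, 3, 4, 4, requires_grad=True)])
)
print(scripted().shape)  # torch.Size([1, 3, 4, 4])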

* Added `SimpleTensorParameterization` as a workaround for JIT not supporting `nn.ParameterList`. It also helps `StackImage` support tensor inputs.
* Added JIT support for `SharedImage`.
* Added a new parameterization called `StackImage`, which stacks multiple parameterizations (that can be on different devices) along the batch dimension (see the sketch after this list).
* Added `AugmentedImageParameterization` class to use as a base for `SharedImage` and `StackImage`.
* Removed `PixelImage`'s 3 channel assert, as there was no reason for the limitation.
* Added tests for `InputParameterization`, `ImageParameterization`, & `AugmentedImageParameterization`.
* Added JIT support for SharedImage's interpolation operations.
* Unfortunately, JIT support required me to separate `SharedImage`'s bilinear and trilinear resizing into separate functions, as Unions of tuples are currently broken. Union support was also a newer addition, so now `SharedImage` can support older PyTorch versions as well.
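
To make the multi-device stacking concrete, here is a minimal eager-mode sketch of the StackImage idea (the MultiDeviceStack and ToyImageParam names below are illustrative, not the PR's implementation): each sub-parameterization can live on its own device, its output is moved to a common output device, and the results are concatenated along the batch dimension.

# Illustrative sketch only; the actual StackImage implementation differs.
import torch
import torch.nn as nn

class MultiDeviceStack(nn.Module):
    def __init__(self, parameterizations, output_device: str = "cpu") -> None:
        super().__init__()
        # Each entry is an image parameterization module that returns an NCHW
        # tensor and may have been placed on a different device by the caller.
        self.parts = nn.ModuleList(parameterizations)
        self.output_device = torch.device(output_device)

    def forward(self) -> torch.Tensor:
        # Collect every sub-parameterization's output on one device, then
        # concatenate along the batch (first) dimension.
        images = [p().to(self.output_device) for p in self.parts]
        return torch.cat(images, dim=0)

# Toy parameterization that just returns a stored trainable image tensor.
class ToyImageParam(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.image = nn.Parameter(torch.randn(1, 3, 32, 32))

    def forward(self) -> torch.Tensor:
        return self.image

# e.g. move each part to "cuda:0" / "cuda:1" before stacking, if available.
parts = [ToyImageParam(), ToyImageParam()]
stacked = MultiDeviceStack(parts, output_device="cpu")
print(stacked().shape)  # torch.Size([2, 3, 32, 32])
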
@ProGamerGov force-pushed the optim-wip-jit-images-support branch from 3d442ec to 9b7ce93 on January 1, 2022, 16:02
@ProGamerGov (Contributor Author) commented:

The test_py37_conda test failed due to a stochastic error, and thus it can be ignored.

@NarineK (Contributor) commented May 15, 2022

@ProGamerGov, are there any specific benefits of adding JIT support? I thought that the models that we work with aren't JIT-scripted. I see that we already merged a large JIT-related PR without reviewing it. I was thinking that it could be added later, because we need to do additional reviews for JIT and @vivekmig would be the best POC for it.

@ProGamerGov (Contributor Author) commented May 15, 2022

@NarineK The image parameterizations and transforms can very much benefit from being JIT scripted, without having to JIT script the models themselves. The other JIT PR covered all of the image parameterizations except the final ones in this PR, so that its size wouldn't be too large.

Though, we can skip this PR for now if you want.

@NarineK (Contributor) commented May 15, 2022

@ProGamerGov, what kind of benefit do we get from JIT for transformations and image parameterization? Is it runtime performance? Did we measure the benefit?

@ProGamerGov (Contributor Author) commented May 15, 2022

@NarineK In my limited testing, JIT scripting the image parameterizations and transforms can shave off a small amount of time for standard rendering, though it will vary based on the parameters / settings used. I haven't really done an in-depth dive into the performance changes, though, so the results could vary.

@ProGamerGov (Contributor Author) commented May 15, 2022

@NarineK Upon testing, there is a clear difference in terms of speed when using JIT scripted parameterizations and transforms:

# https://docs.python.org/3/library/timeit.html
import timeit

setup = """
import captum.optim as opt
import torch
test_transform = torch.jit.script(opt.transforms.TransformationRobustness())
image = torch.jit.script(opt.images.NaturalImage((224, 224), batch=4).cuda())
"""

code = """
_ = test_transform(image())
"""

for _ in range(10):
    output = timeit.timeit(stmt=code, setup=setup, number=5000)
    print(output)

Remove the torch.jit.script calls in the above code to test without JIT.

When I tested the above code on Colab, I got the following results:

# With JIT
10.923049781000032
11.118382118
11.57444840200003
10.89497750999999
10.775770954999984
10.848362547999955
10.786963311000022
10.801879264000036
10.77277677799998
10.866098210000018
# Without JIT
14.257542533999981
14.208807915999955
14.269374133000042
14.300639500999978
14.169171004999953
14.305170376999968
14.284609974000091
14.262493259999928
14.304441671000063
14.142993052999941

@NarineK (Contributor) commented May 16, 2022

Thank you, @ProGamerGov! I'm rerunning the tests. Hopefully the conda issue will get fixed.

@ProGamerGov (Contributor Author) commented:

@NarineK Looks like the Conda test failed again due to a stochastic error.

@NarineK (Contributor) commented May 17, 2022

It's not clear to me that it is a flaky error caused by stochastic behavior. It is probably a version issue. Some people were able to fix it by downgrading the conda version. Let's leave it as is and see if we see the same issue on master. I don't see it in master PRs.

@NarineK merged commit 5e18711 into pytorch:optim-wip on May 17, 2022.