Optim-wip: Add new StackImage parameterization & JIT support for SharedImage #833
Conversation
* Added `SimpleTensorParameterization` as a workaround for JIT not supporting `nn.ParameterList`. It also helps `StackImage` support tensor inputs (a minimal sketch follows this list).
* Added JIT support for `SharedImage`.
* Added a new parameterization called `StackImage`, which stacks multiple parameterizations (that can be on different devices) along the batch dimension.
* Added the `AugmentedImageParameterization` class to use as a base for `SharedImage` and `StackImage`.
* Removed `PixelImage`'s 3 channel assert, as there was no reason for the limitation.
* Added tests for `InputParameterization`, `ImageParameterization`, & `AugmentedImageParameterization`.
* Added JIT support for `SharedImage`'s interpolation operations. Unfortunately, JIT support required me to separate `SharedImage`'s bilinear and trilinear resizing into separate functions, as Unions of tuples are currently broken. Union support was also a newer addition, so now `SharedImage` can support older PyTorch versions as well.
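Since the class body isn't shown in this thread, here is a minimal sketch of what a tensor-wrapping workaround like `SimpleTensorParameterization` might look like. Only the class name comes from the PR; the body and its details are assumptions:

```python
import torch
import torch.nn as nn


class SimpleTensorParameterization(nn.Module):
    """Illustrative sketch, not the PR's actual code. A list of these
    modules can be held in an nn.ModuleList (which TorchScript supports)
    as a stand-in for nn.ParameterList, and plain tensors can be wrapped
    the same way, which is how StackImage could accept tensor inputs."""

    def __init__(self, tensor: torch.Tensor) -> None:
        super().__init__()
        if tensor.requires_grad:
            # Trainable input: expose it to optimizers as a Parameter.
            self.tensor = nn.Parameter(tensor)
        else:
            # Fixed input: a buffer still moves with .to(device) / .cuda().
            self.register_buffer("tensor", tensor)

    def forward(self) -> torch.Tensor:
        return self.tensor
```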
Force-pushed from 3d442ec to 9b7ce93.
* Added the `dim` argument to `StackImage` so that users can choose which dimension to stack the image parameterizations along (see the usage sketch below).
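As a usage illustration only — the import path, constructor signatures, and default values here are assumptions, not the PR's actual API:

```python
import torch
from captum.optim._param.image.images import PixelImage, StackImage  # path assumed

# Two parameterizations that could live on different devices in a
# multi-GPU setup; shown on the default device here for simplicity.
param_a = PixelImage(size=(224, 224))
param_b = PixelImage(size=(224, 224))

# Stack both along the batch dimension; dim=0 assumed as the default.
stack = StackImage([param_a, param_b], dim=0)
out = stack()  # expected: one tensor holding both images stacked on dim 0
```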
@ProGamerGov, are there any specific benefits to adding JIT support? I thought that the models we work with aren't JIT-scripted. I see that we already merged a large JIT-related PR without reviewing it. I was thinking that it could be added later, because we need to do additional reviews for JIT, and @vivekmig would be the best POC for it.
@NarineK The image parameterizations and transforms can very much benefit from being JIT-scripted, without having to JIT-script the models themselves. The other JIT PR covered all of the image parameterizations except the final ones in this PR, so that its size wouldn't be too large. Though, we can skip this PR for now if you want.
@ProGamerGov, what kind of benefit do we get from JIT for the transforms and image parameterizations? Is it runtime performance? Did we measure the benefit?
@NarineK In my limited testing, JIT-scripting the image parameterizations and transforms can shave off a small amount of time for standard rendering, though it will vary based on the parameters/settings used. I haven't really done an in-depth dive into the performance changes though, so the results could vary.
@NarineK Upon testing, there is a clear difference in speed when using JIT-scripted parameterizations and transforms. When I tested the code on Colab, I got the following results:
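The benchmark snippet and numbers from this comment aren't preserved in the thread. A minimal sketch of this kind of eager-vs-scripted timing comparison might look like the following; the import path, and the assumption that the parameterization is scriptable (as this PR intends), are mine:

```python
import time

import torch
from captum.optim._param.image.images import PixelImage  # path assumed

param = PixelImage(size=(224, 224))
scripted = torch.jit.script(param)  # assumes the module is JIT-scriptable


def bench(fn, iters: int = 1000) -> float:
    for _ in range(10):  # warm-up so JIT profiling/fusion passes run first
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - start


print(f"eager:    {bench(param):.3f}s")
print(f"scripted: {bench(scripted):.3f}s")
```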
Thank you, @ProGamerGov! I'm rerunning the tests. Hopefully the conda issue will get fixed.
@NarineK Looks like the conda test failed again due to a stochastic error.
It's not clear to me that it is a flaky error caused by stochastic behavior. It is probably a version issue. Some people were able to fix it by downgrading their conda version. Let's leave it as is and see if we see the same issue on master. I don't see it in master PRs.
* Added `SimpleTensorParameterization` as a workaround for JIT not supporting `nn.ParameterList`. It also helps `StackImage` support tensor inputs. Edit: later versions of PyTorch now support `nn.ParameterList`, but the current solution supports all torch versions supported by Captum.
* Added JIT support for `SharedImage`. I had to separate the interpolation operations due to a bug with JIT's Union support: JIT RuntimeError: 'Union[Tensor, List[float], List[int]]' object is not subscriptable pytorch#69434 (see the sketch after this list).
* Added a new parameterization called `StackImage`, which stacks multiple parameterizations (that can be on different devices) along the batch dimension. Lucid uses this parameterization a few times, and I thought that I could improve on it with the multi-device functionality.
* Added tests for `InputParameterization` and `ImageParameterization`.
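To make that Union workaround concrete, here is a hedged sketch of what splitting the resize paths into concretely-typed functions might look like; the names and signatures are illustrative, not the PR's actual code:

```python
from typing import List

import torch
import torch.nn.functional as F


def _interp_bilinear(x: torch.Tensor, size: List[int]) -> torch.Tensor:
    # size = [H, W]; expects a 4-D NCHW tensor.
    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


def _interp_trilinear(x: torch.Tensor, size: List[int]) -> torch.Tensor:
    # size = [C, H, W]; trilinear needs a 5-D input, so add and then
    # remove a leading dim around the 4-D NCHW tensor.
    return F.interpolate(
        x.unsqueeze(0), size=size, mode="trilinear", align_corners=False
    ).squeeze(0)
```

Because each function takes a plain `List[int]` rather than a Union of tuple types, both can be compiled with `torch.jit.script` without hitting the Union subscript bug (pytorch#69434).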