Optim-wip: Add JIT support to all transforms & some image parameterizations #821
Looks like there was an error with the Insights install in the tests. Everything optim-related is working fine, however.
Force-pushed from `ab73da6` to `afc3434`.
* JIT support for `center_crop`.
* Improve some transform tests.
* Fix `RandomCrop` transform bug.
Force-pushed from `4a1ec1c` to `9952809`.
* Replace Affine `RandomScale` with an interpolation-based variant. Renamed the old variant to `RandomScaleAffine`.
* `CenterCrop` & `center_crop` now use padding if the crop size is larger than the input dimensions.
* Add distributions support to both versions of `RandomScale`.
* Improve transform tests.
* Add `torch.distributions.distribution.Distribution` to `NumSeqOrTensorType` type hint.
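A hedged reconstruction of what the widened type alias might look like; the exact members of Captum's `NumSeqOrTensorType` alias are an assumption here, only the addition of `Distribution` is stated in the PR:

```python
from typing import List, Union

import torch
from torch.distributions.distribution import Distribution

# Hypothetical reconstruction of the widened alias described above; the
# non-Distribution members are assumed, not taken from the Captum source.
NumSeqOrTensorOrProbDistType = Union[
    float, int, List[float], List[int], torch.Tensor, Distribution
]
```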
* Added `TransformationRobustness()` transform.
* Fixed bug with `center_crop` padding code, and added related tests to `center_crop` & `CenterCrop`.
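A minimal sketch of the crop-or-pad behavior described above, assuming a square crop; `center_crop_or_pad` is a hypothetical helper, not the actual Captum implementation:

```python
import torch
import torch.nn.functional as F

def center_crop_or_pad(x: torch.Tensor, size: int) -> torch.Tensor:
    """Center-crop a NCHW tensor, padding first if it is smaller than size."""
    h, w = x.shape[-2], x.shape[-1]
    if h < size or w < size:
        # Pad symmetrically out to at least the requested size.
        pad_h, pad_w = max(size - h, 0), max(size - w, 0)
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        h, w = x.shape[-2], x.shape[-1]
    top, left = (h - size) // 2, (w - size) // 2
    return x[..., top : top + size, left : left + size]

small = center_crop_or_pad(torch.ones(1, 3, 4, 4), 8)    # padded up to 8x8
large = center_crop_or_pad(torch.ones(1, 3, 16, 16), 8)  # cropped down to 8x8
```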
Force-pushed from `6d832db` to `d2583d4`.
* Add JIT support to `NaturalImage`, `FFTImage`, & `PixelImage`.
* Added proper JIT support for `ToRGB`.
* Improved `NaturalImage` & `FFTImage` tests, and test coverage.
* Added `ImageParameterization` instance support for `NaturalImage`. This improvement should make it easier to use parameterization enhancements like `SharedImage`, and will be helpful for custom parameterizations that don't use the standard input variable set (size, channels, batch, & init).
* Added asserts to verify `NaturalImage` parameterization inputs are instances or types of `ImageParameterization`.
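The instance-or-class handling with asserts could look roughly like this sketch, with placeholder classes standing in for the real Captum ones:

```python
from typing import Type, Union

class ImageParameterization:  # placeholder for the real Captum base class
    pass

class FFTImage(ImageParameterization):  # placeholder subclass
    def __init__(self, size=(224, 224), channels=3, batch=1, init=None):
        self.size = size

def resolve_parameterization(
    p: Union[ImageParameterization, Type[ImageParameterization]], **kwargs
) -> ImageParameterization:
    """Accept either an ImageParameterization instance or class."""
    if isinstance(p, type):
        # A class was passed: verify it, then build it from the kwargs.
        assert issubclass(p, ImageParameterization)
        p = p(**kwargs)
    # Either way, the result must be an ImageParameterization instance.
    assert isinstance(p, ImageParameterization)
    return p
```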
This should make it easier to work with the `ToRGB` module, as many PyTorch functions don't yet work with named dimensions.
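The dual-forward idea can be sketched as follows, assuming a `torch.jit.is_scripting()` dispatch with `@torch.jit.unused` hiding the named-tensor path from the compiler; `ToRGBLike` is a hypothetical stand-in, not Captum's `ToRGB`:

```python
import torch
import torch.nn as nn

class ToRGBLike(nn.Module):
    @torch.jit.unused
    def _forward_named(self, x: torch.Tensor) -> torch.Tensor:
        # Eager path: uses named dimensions, which TorchScript can't compile.
        named = x.refine_names("B", "C", "H", "W")
        return (named * 2.0).rename(None)

    def _forward_jit(self, x: torch.Tensor) -> torch.Tensor:
        # JIT-safe path: positional dimensions only.
        return x * 2.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            return self._forward_jit(x)
        return self._forward_named(x)
```

Both paths compute the same result; `@torch.jit.unused` makes the compiler replace the named-tensor method with a stub so `torch.jit.script` succeeds.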
* The maximum of 4 channels isn't required, as we ignore all channels after the third.
The `linear` mode only supports 3D inputs, and `trilinear` only supports 5D inputs. `RandomScale` only uses 4D inputs, so only `nearest`, `bilinear`, `bicubic`, & `area` are supported.
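For example, `torch.nn.functional.interpolate` accepts the four 4D-compatible modes on an NCHW tensor but rejects `linear`:

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 3, 8, 8)  # RandomScale operates on 4D NCHW inputs

# Modes that accept 4D inputs:
y_nearest = F.interpolate(x, scale_factor=0.5, mode="nearest")
y_bilinear = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
y_bicubic = F.interpolate(x, scale_factor=0.5, mode="bicubic", align_corners=False)
y_area = F.interpolate(x, scale_factor=0.5, mode="area")

# "linear" is 3D-only (and "trilinear" 5D-only), so 4D input is rejected:
try:
    F.interpolate(x, scale_factor=0.5, mode="linear")
except NotImplementedError:
    pass  # e.g. "Got 4D input, but linear mode needs 3D input"
```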
…transforms-support

…mationRobustness`
* Change `RandomRotation` type hint from `NumSeqOrTensorType` to `NumSeqOrTensorOrProbDistType`.
* Uncomment `RandomRotation` from `TransformationRobustness` & tests.
The optimizing with transparency tutorial notebook requires this PR.
* Added JIT support to all transforms.
* Added JIT tests for all applicable transforms.
* There's currently a bug that prevents `CenterCrop` from working with JIT, so I added the `@torch.jit.ignore` decorator to its forward function to avoid the bug. This way it won't throw an error if someone uses `torch.jit.script` on it. (JIT is overriding variables of type `List[int]` to type `Tuple[int, int]` for an unknown reason during function calls: pytorch#69938.)
* JIT is also told to ignore `NChannelsToRGB` due to its unsupported inner functions.
* The `ToRGB` module didn't support JIT due to the usage of named dimensions, so I implemented a second forward function that is only used under JIT.
* Add JIT support to `NaturalImage`, `FFTImage`, & `PixelImage`.
* Improved `NaturalImage` & `FFTImage` tests, and test coverage.
* Added `ImageParameterization` instance support for `NaturalImage`. This improvement should make it easier to use parameterization enhancements like `SharedImage`, and will be helpful for custom parameterizations that don't use the standard input variable set (size, channels, batch, & init).
* Added asserts to verify `NaturalImage` parameterization inputs are instances or types of `ImageParameterization`.
* Some transform type hints had to be changed due to a bug in JIT's Union support: JIT `RuntimeError: 'Union[Tensor, List[float], List[int]]' object is not subscriptable` (pytorch#69434).
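The `@torch.jit.ignore` workaround can be sketched like this, with `CropLike` as a hypothetical stand-in for `CenterCrop`:

```python
import torch
import torch.nn as nn

class CropLike(nn.Module):
    def __init__(self, size: int = 4) -> None:
        super().__init__()
        self.size = size

    # Hiding forward from the scripting compiler means torch.jit.script()
    # succeeds, and the method simply keeps running as ordinary Python.
    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[..., : self.size, : self.size]

scripted = torch.jit.script(CropLike())  # compiles without error
out = scripted(torch.ones(1, 3, 8, 8))   # forward still executes eagerly
```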
* Fixed bug with the `RandomCrop` transform.
* Replaced the default `RandomScale` transform with an interpolation-based variant. Renamed the old variant to `RandomScaleAffine`.
* `CenterCrop` / `center_crop` now adds padding if the crop size is larger than the input dimensions.
* Add distributions support to both versions of `RandomScale`. Ludwig wanted this based off of his original PR.
* Changed the `NumSeqOrTensorType` hint to `NumSeqOrTensorOrProbDistType`.
* Added more comprehensive testing to applicable transforms. Test coverage reported by pytest should now be 100%, minus the version-specific tests.
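A minimal sketch of interpolation-based scaling driven by a `torch.distributions` sampler; `random_scale` here is a hypothetical helper, not Captum's implementation:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Uniform

def random_scale(
    x: torch.Tensor, scale_dist: torch.distributions.Distribution
) -> torch.Tensor:
    """Rescale a 4D NCHW tensor by a factor sampled from a distribution."""
    scale = float(scale_dist.sample())
    return F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)

x = torch.ones(1, 3, 32, 32)
y = random_scale(x, Uniform(0.5, 1.5))  # scale drawn uniformly from [0.5, 1.5)
```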
* Transform torch version checks required for module forward functions are performed in their `__init__` function so that they work with JIT.
* Added the `TransformationRobustness()` transform. Like `standard_transforms`, it is a convenient helper for applying the standard set of transforms in the correct order, and its default values have been tuned to optimal values. It applies `RandomSpatialJitter` transforms, a `RandomScale` transform, a `RandomRotation` transform, a single `RandomSpatialJitter` transform, and then a final optional `CenterCrop` or pad-to transform. This makes it easier to use transform robustness properly.
* JIT support for the InceptionV1 model was added in #655.
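The transform ordering could be sketched as below, with `nn.Identity` standing in for the real captum.optim transform modules (the real ones apply actual geometric transforms):

```python
import torch
import torch.nn as nn

# nn.Identity stand-ins for the real captum.optim transforms.
RandomSpatialJitter = RandomScale = RandomRotation = CenterCrop = nn.Identity

transform_robustness = nn.Sequential(
    RandomSpatialJitter(),  # several initial jitter steps
    RandomSpatialJitter(),
    RandomScale(),          # interpolation-based scaling
    RandomRotation(),
    RandomSpatialJitter(),  # a single final jitter
    CenterCrop(),           # optional final crop (or pad) step
)

out = transform_robustness(torch.ones(1, 3, 64, 64))
```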
This PR only changes or adds around 711 lines of code, with the other 1306 lines coming from tests that contain a lot of redundant, simple code.