
Upsample2d #414

Merged · 27 commits merged into ml-explore:main on Feb 23, 2024
Conversation

@gboduljak (Contributor) commented Jan 10, 2024

Proposed changes

Implemented two-dimensional upsampling, using nearest neighbor and bilinear interpolation.
The implementation was tested by comparison to PyTorch. Closes #73.

My only concern is performance. More precisely, nearest neighbor is implemented using as_strided and a reshape, while bilinear is implemented by writing out the interpolation equations directly in Python. However, there is a reshape from (B, H, W, C) to (B, C, H, W); I struggled to get the indexing right without it, and any suggestion to avoid it is welcome. Apart from this, there are existing upsampling kernels in MPS, namely MPSCNNUpsamplingBilinear and MPSCNNUpsamplingNearest. Should we use these?
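
For illustration, here is a minimal sketch of the bilinear variant written directly from the interpolation equations on a channels-last array (a sketch only, assuming an integer scale, spatial dims larger than one, and the corner mapping matching torch.nn.UpsamplingBilinear2d used in the testing script below; it is not the exact code in this PR):

import mlx.core as mx

def bilinear_upsample_2d(x: mx.array, scale: int) -> mx.array:
    # x: (B, H, W, C), channels-last; output corners map onto input corners.
    B, H, W, C = x.shape
    new_H, new_W = H * scale, W * scale
    # Fractional source coordinates for every output row / column.
    ys = mx.arange(new_H) * ((H - 1) / (new_H - 1))
    xs = mx.arange(new_W) * ((W - 1) / (new_W - 1))
    y0 = mx.floor(ys).astype(mx.int32)
    x0 = mx.floor(xs).astype(mx.int32)
    y1 = mx.minimum(y0 + 1, H - 1)
    x1 = mx.minimum(x0 + 1, W - 1)
    wy = (ys - y0).reshape(1, new_H, 1, 1)
    wx = (xs - x0).reshape(1, 1, new_W, 1)

    def gather(yi, xi):
        # Select rows, then columns, keeping the (B, ..., C) layout.
        rows = mx.take(x, yi, axis=1)      # (B, new_H, W, C)
        return mx.take(rows, xi, axis=2)   # (B, new_H, new_W, C)

    a, b = gather(y0, x0), gather(y0, x1)
    c, d = gather(y1, x0), gather(y1, x1)
    # Weighted sum of the four nearest neighbours.
    return (
        a * (1 - wy) * (1 - wx)
        + b * (1 - wy) * wx
        + c * wy * (1 - wx)
        + d * wy * wx
    )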

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Testing Script

import numpy as np
import mlx.core as mx
import mlx.nn as nn
import torch
import torch.nn as tnn

b = 32
h = 8
w = 8
c = 64
scales = [2, 3, 4, 5]
num_iter = 100

for _ in range(num_iter):
    imgs = mx.random.normal((b, h, w, c))
    imgs_tc = torch.tensor(imgs.transpose(
        (0, 3, 1, 2)).tolist(), dtype=torch.float32
    )
    for scale in scales:
        # np.allclose returns a bool; assert on it so a mismatch actually fails.
        assert np.allclose(
            a=nn.Upsample2d(
                scale=scale,
                mode="bilinear"
            )(imgs).transpose((0, 3, 1, 2)),
            b=tnn.UpsamplingBilinear2d(scale_factor=scale)(imgs_tc)
        )
        assert np.allclose(
            a=nn.Upsample2d(
                scale=scale,
                mode="nearest"
            )(imgs).transpose((0, 3, 1, 2)),
            b=tnn.UpsamplingNearest2d(scale_factor=scale)(imgs_tc)
        )

@gboduljak marked this pull request as ready for review on January 10, 2024 02:21
@gboduljak changed the title from [Draft] Upsample2d to Upsample2d on Jan 10, 2024
@angeloskath (Member)

I refactored it a bit and removed the need for transposing anything. It is still going to be significantly slower than a bespoke kernel but it can be merged as a stop gap. @gboduljak let me know what you think of the changes and @awni feel free to review.

Inline review comment from @awni (Member), on this part of the diff:

return w_a * a + w_b * b + w_c * c + w_d * d


class Upsample2d(Module):

I notice PyTorch has a single Upsample class which handles different dimensions. It might be worth making that consistent and then throwing (or supporting) on the dimensions not yet handled.
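
As a rough illustration of that suggestion, a single module could check the input rank and raise for dimensionalities that are not handled yet. The sketch below is hypothetical (the class body and names are illustrative, not the code in this PR) and only implements the nearest-neighbor 2D case:

import mlx.core as mx
import mlx.nn as nn

class Upsample(nn.Module):
    # Single entry point that dispatches on the input rank (hypothetical sketch).
    def __init__(self, scale_factor: int, mode: str = "nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def __call__(self, x: mx.array) -> mx.array:
        if x.ndim != 4:
            # Only (B, H, W, C) inputs are handled here; 1D/3D upsampling would
            # be added as extra branches or raise until implemented.
            raise NotImplementedError(
                f"Upsampling is only implemented for 4D inputs, got {x.ndim} dims"
            )
        if self.mode != "nearest":
            raise NotImplementedError(f"mode={self.mode!r} is not implemented in this sketch")
        B, H, W, C = x.shape
        s = self.scale_factor
        # Nearest neighbor: repeat each pixel s times along H and W.
        y = x.reshape(B, H, 1, W, 1, C)
        y = mx.broadcast_to(y, (B, H, s, W, s, C))
        return y.reshape(B, H * s, W * s, C)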

@gboduljak (Contributor, Author) replied

@awni Thank you for this suggestion. I was thinking of implementing upsampling within nn.Upsample. I had a similar idea/comment on #357. There, I went with nn.Pooling instead of nn.MaxPooling1d or nn.MaxPooling2d and @angeloskath suggested we go for a different class based on the dimension or pooling type. Thus, to be consistent with that review, I implemented nn.Upsample2d. In my opinion, 2D upsampling is also the most common use case.
Could you please share your thoughts on whether we want nn.Upsample2d or nn.Upsample, based on what we might have for pooling?

@awni (Member) left a review

Looks nice! Some minor suggestions inline.

Also regarding the naming choices, it might be worth making it consistent with PyTorch unless there is a good reason not to. Wdyt?

@gboduljak (Contributor, Author)

@angeloskath @awni Thank you for the work on this and the review. I will take a detailed look at this tomorrow and I will address the comments.

gboduljak and others added 5 commits February 9, 2024 01:23
@gboduljak (Contributor, Author)

> Also regarding the naming choices, it might be worth making it consistent with PyTorch unless there is a good reason not to. Wdyt?

Please see my comment regarding the naming choice in the unresolved conversation. We sometimes match PyTorch and sometimes we do not. I am not sure what is the best option here. I would love to know @angeloskath's opinion as well.

@gboduljak (Contributor, Author)

> I refactored it a bit and removed the need for transposing anything. It is still going to be significantly slower than a bespoke kernel but it can be merged as a stop gap. @gboduljak let me know what you think of the changes and @awni feel free to review.

Many thanks for your rewrite. It looks much better than the original implementation, due to the lack of transposes. I also like the refactor based on the scale type. In addition, I removed an unused import.

Should we rename scale to scale_factor?

@awni (Member) commented Feb 9, 2024

> Could you please share your thoughts on whether we want nn.Upsample2d or nn.Upsample, based on what we might have for pooling?

My comment should have been phrased more as a question: is there a good reason to specialize to 2D vs doing a more general implementation?

  • Consistency within our own layers might be one reason (e.g. with Pooling)
  • Simplicity of implementation and/or use could be another

> We sometimes match PyTorch and sometimes we do not. I am not sure what is the best option here.

We are inconsistent there, but I think that's actually OK. First priority: choose the best option for MLX (is it simple, is it usable, is the API clear, etc.). No need to inherit baggage. If that's reasonably satisfied, then consistency is a good follow-on. Since we already are not consistent, I don't think we should be too draconian about it; just aim for it when it doesn't get in the way.

@gboduljak (Contributor, Author)

@awni Here are my thoughts on this.

> My comment should have been phrased more as a question: is there a good reason to specialize to 2D vs doing a more general implementation?

To the best of my knowledge, this upsampling operation is most commonly used in computer vision and it usually assumes either 2D or 3D input data. The main reason why I focused on 2D upsampling is that 2D vision is currently better supported in mlx than 3D. However, we can (easily) generalize the current implementation to 3D, but we need to adapt the logic for interpolation. We can infer whether to use 2D or 3D upsampling based on runtime shapes. Since we do not have 3D implemented yet, I think we should simply raise an error that the 3D operation is not supported.

To conclude, I think it is best to rename this layer to nn.Upsample, as in PyTorch, and throw an error when an unimplemented operation is requested. @angeloskath, what do you think?

I plan to resolve the conflicts and do the refactor today.

@gboduljak (Contributor, Author)

@awni I renamed Upsample2d to Upsample and scale to scale_factor. This makes our naming consistent with PyTorch. However, there is a small inconsistency that we should take into account before merging this. It is about PyTorch's align_corners behavior. For simplicity, I decided not to implement the align_corners behavior in mlx. This implies that our implementation of Upsample is equivalent to PyTorch's Upsample(..., align_corners=True). However, PyTorch's default is Upsample(..., align_corners=False). My questions are:

  1. Should we implement align_corners in this PR?
  2. If we do not want to implement align_corners now, should we discuss our deviation from PyTorch defaults in the documentation?
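
For context, the practical difference between the two conventions is easy to see directly in PyTorch (a small standalone snippet, not part of the PR):

import torch
import torch.nn as tnn

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
up_true = tnn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
up_false = tnn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
# The two settings interpolate from different source coordinates, so the
# outputs differ; this prints False.
print(torch.allclose(up_true(x), up_false(x)))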

@awni (Member) commented Feb 14, 2024

Thanks!

> Should we implement align_corners in this PR?

🤔 is it useful to have it? Is it complex to implement?

> If we do not want to implement align_corners now, should we discuss our deviation from PyTorch defaults in the documentation?

I think explaining the behavior is a good idea!

I'm fine to leave the default + a little doc and do a follow up for align_corners if/when it's needed. Also fine to include it here if it's a simple + useful addition. I can take a look at their docs to understand more, but if you have thoughts on that please share.

@angeloskath (Member)

@gboduljak changing the indexes with the following would be enough to implement the align_corners=False case.

# H, W are the input spatial dims; new_H, new_W the upsampled output dims.
idx_y = mx.arange(0, new_H) * ((H - 0.5) / (new_H - 1)) - 0.25
idx_x = mx.arange(0, new_W) * ((W - 0.5) / (new_W - 1)) - 0.25
idx_y = mx.clip(idx_y, 0, H - 1)
idx_x = mx.clip(idx_x, 0, W - 1)

It is probably less effort to support it than explain why it wouldn't be supported.

Regarding the explanation, I think PyTorch's is not very clean, and I would simply say: align_corners=True means that location (0, 0) is the top-left corner of the image and (H, W) the bottom-right, while align_corners=False means that these correspond to the centers of the top-left and bottom-right pixels, respectively.

Let me know if you'd rather I make these changes.
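
For reference, below is a sketch of the per-axis index mappings most frameworks use for the two conventions (the standard formulation, not necessarily the exact code merged in this PR):

import mlx.core as mx

def source_indices(new_N: int, N: int, align_corners: bool) -> mx.array:
    # Map each output index along one axis to a fractional source index.
    dst = mx.arange(new_N)
    if align_corners:
        # Output endpoints land exactly on input endpoints.
        return dst * ((N - 1) / (new_N - 1))
    # Half-pixel-center convention, clipped at the borders.
    src = (dst + 0.5) * (N / new_N) - 0.5
    return mx.clip(src, 0, N - 1)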

@gboduljak (Contributor, Author)

> @gboduljak changing the indexes with the following would be enough to implement the align_corners=False case.

Given the simplicity of the implementation, I think we should implement align_corners. I wanted to confirm we indeed want this before implementing it. If it is important to merge this as soon as possible, @angeloskath, please go ahead and make the changes. Otherwise, I will aim to implement align_corners, update the docs, and add the tests tomorrow.

@angeloskath (Member)

Ok, this is probably the last iteration. I generalized @gboduljak's code to many dims because I do think 1D upsampling is needed, and since I was going to generalize, why not 3D or N-D as well.

I also fixed align_corners=False, because the version I had suggested was actually not correct for scales other than 2, but the implementation is equally simple.

One question that I have regarding API. I do think it is kinda useless to have to specify linear vs bilinear vs trilinear. It is all linear. Do you think it makes sense to simply have the mode be linear and describe in the docs that it will be bi, tri, etc? Probably not but wanted to pick your brains.

@awni

@gboduljak (Contributor, Author) commented Feb 17, 2024

First of all, thank you (@angeloskath) for generalizing upsampling to many dimensions.

> One question that I have regarding API. I do think it is kinda useless to have to specify linear vs bilinear vs trilinear. It is all linear.

I agree that it is not important to distinguish between linear, bilinear, and trilinear. Using just linear results in a simpler API, and it can be clearly explained in the docs.

@awni (Member) commented Feb 20, 2024

I agree with you both: seems cleaner to specify just "linear" for those cases.
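
Putting the naming decisions in this thread together (Upsample, scale_factor, mode "nearest" or "linear", and align_corners), usage of the merged layer should look roughly like this (a usage sketch based on the names agreed above, not taken from the final docs):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 8, 8, 3))  # (B, H, W, C), channels-last

nearest = nn.Upsample(scale_factor=2, mode="nearest")
linear = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

print(nearest(x).shape)  # (1, 16, 16, 3)
print(linear(x).shape)   # (1, 16, 16, 3)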

@angeloskath (Member)

This should be good to review and merge if all ok.

@angeloskath requested a review from awni on February 21, 2024 18:36
@angeloskath merged commit 22364c4 into ml-explore:main on Feb 23, 2024
4 checks passed
Successfully merging this pull request may close these issues: Upsample support (#73).

3 participants