Upsample2d #414
Conversation
(force-pushed from 585f868 to 69f0643)
I refactored it a bit and removed the need for transposing anything. It is still going to be significantly slower than a bespoke kernel, but it can be merged as a stopgap. @gboduljak let me know what you think of the changes and @awni feel free to review.
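For a sense of what a transpose-free, gather-based approach can look like, here is a minimal sketch of nearest-neighbor upsampling that stays entirely in the (B, H, W, C) layout. Names like nearest_upsample are illustrative, not the PR's actual code:

import mlx.core as mx

def nearest_upsample(x, scale=2):
    # x: (B, H, W, C). Gather repeated row/column indices along each
    # spatial axis; no transpose to channels-first is needed.
    B, H, W, C = x.shape
    idx_y = mx.arange(H * scale) // scale   # 0,0,1,1,... for scale=2
    idx_x = mx.arange(W * scale) // scale
    x = mx.take(x, idx_y, axis=1)           # (B, H*scale, W, C)
    x = mx.take(x, idx_x, axis=2)           # (B, H*scale, W*scale, C)
    return x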
python/mlx/nn/layers/upsample.py (outdated):

    return w_a * a + w_b * b + w_c * c + w_d * d
    ...
class Upsample2d(Module):
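For context on that line: w_a through w_d are the standard bilinear corner weights, formed from the fractional part of the mapped source coordinate. A minimal sketch of how they arise, with illustrative names:

import mlx.core as mx

def bilinear_weights(y, x):
    # y, x: fractional source coordinates for each output pixel
    dy = y - mx.floor(y)             # vertical offset in [0, 1)
    dx = x - mx.floor(x)             # horizontal offset in [0, 1)
    w_a = (1 - dy) * (1 - dx)        # top-left corner
    w_b = (1 - dy) * dx              # top-right corner
    w_c = dy * (1 - dx)              # bottom-left corner
    w_d = dy * dx                    # bottom-right corner
    return w_a, w_b, w_c, w_d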
I notice PyTorch has a single Upsample class which handles different dimensions. It might be worth making that consistent and then throwing (or supporting) on the dimensions not yet handled.
@awni Thank you for this suggestion. I was thinking of implementing upsampling within nn.Upsample. I had a similar idea/comment on #357. There, I went with nn.Pooling instead of nn.MaxPooling1d or nn.MaxPooling2d, and @angeloskath suggested we go for a different class based on the dimension or pooling type. Thus, to be consistent with that review, I implemented nn.Upsample2d. In my opinion, 2D upsampling is also the most common use case.

Could you please share your thoughts on whether we want nn.Upsample2d or nn.Upsample, based on what we might have for pooling?
Looks nice! Some minor suggestions inline.
Also regarding the naming choices, it might be worth making it consistent with PyTorch unless there is a good reason not to. Wdyt?
@angeloskath @awni Thank you for the work on this and the review. I will take a detailed look at this tomorrow and I will address the comments.
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Please see my comment regarding the naming choice in the unresolved conversation. We sometimes match PyTorch and sometimes we do not. I am not sure what the best option is here. I would love to know @angeloskath's opinion as well.
Many thanks for your rewrite. It looks much better than the original implementation, due to the lack of transposes. I also like the refactor based on the scale type. I also removed an unused import. Should we rename nn.Upsample2d to nn.Upsample?
My comment should have been phrased more as a question: is there a good reason to specialize to 2D vs doing a more general implementation?
We are inconsistent there. But I think it's ok actually. First priority: choose the best option for MLX (is it simple, is it usable, is the API clear, etc.). No need to inherit baggage. If that's reasonably satisfied, then consistency is a good follow-on. Since we already are not consistent, I don't think we should be too draconian about it; just aim for it when it doesn't get in the way.
@awni Here are my thoughts on this.

To the best of my knowledge, this upsampling operation is most commonly used in computer vision, and it usually assumes either 2D or 3D input data. The main reason why I focused on 2D upsampling is that 2D vision is currently better supported in mlx.

To conclude, I think it is best to rename this layer to nn.Upsample. I plan to resolve the conflicts and do the refactor today.
(force-pushed from 194a052 to c0f68e8)
@awni I renamed nn.Upsample2d to nn.Upsample.
Thanks!
align_corners=False 🤔 is it useful to have it? Is it complex to implement?
I think explaining the behavior is a good idea! I'm fine to leave the default + a little doc and do a follow-up for align_corners=False.
@gboduljak changing the indexes with the following would be enough to implement the align_corners=False behavior:

idx_y = mx.arange(0, new_H) * ((H - 0.5) / (new_H - 1)) - 0.25
idx_x = mx.arange(0, new_W) * ((W - 0.5) / (new_W - 1)) - 0.25
idx_y = mx.clip(idx_y, 0, H - 1)
idx_x = mx.clip(idx_x, 0, W - 1)

It is probably less effort to support it than explain why it wouldn't be supported. Regarding the explanation, I think PyTorch's is not very clean and I would simply say […]. Let me know if you'd rather I make these changes.
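For comparison, the two usual conventions for mapping output to input coordinates look like this (a self-contained sketch with example sizes, not code from this PR):

import mlx.core as mx

H, new_H = 4, 8  # example input/output heights

# align_corners=True: the first and last output samples map exactly
# onto the first and last input pixels
idx_true = mx.arange(new_H) * ((H - 1) / (new_H - 1))

# align_corners=False: the half-pixel-centers convention, clipped at
# the borders (same spirit as the formula above)
idx_false = (mx.arange(new_H) + 0.5) * (H / new_H) - 0.5
idx_false = mx.clip(idx_false, 0, H - 1)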
Given the simplicity of the implementation, I think we should implement align_corners=False.
Ok this is probably the last iteration. I generalized @gboduljak's code to many dims cause I do think 1d upsampling is needed, and since I was gonna generalize, why not 3d or n-d as well. I also fixed the align_corners=False behavior.

One question that I have regarding the API: I do think it is kinda useless to have to specify […]
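As a rough illustration of how the gather idea shown earlier extends to arbitrary dimensions (a sketch under assumed names, not the merged code):

import mlx.core as mx

def nearest_upsample_nd(x, scale_factor):
    # x: (B, *spatial, C); scale_factor holds one integer scale per
    # spatial axis. Gathering along each axis in turn covers 1d, 2d,
    # 3d, and beyond with the same code.
    for axis, s in enumerate(scale_factor, start=1):
        idx = mx.arange(x.shape[axis] * s) // s
        x = mx.take(x, idx, axis=axis)
    return x

y = nearest_upsample_nd(mx.zeros((1, 4, 4, 3)), (2, 2))  # (1, 8, 8, 3)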
First of all, thank you (@angeloskath) for generalizing upsampling to many dimensions.
I agree that it is not important to distinguish between […]
I agree with you both: seems cleaner to specify just […]
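If the API does end up taking just a scale factor, usage might look something like this (hypothetical signature, for illustration only):

import mlx.core as mx
import mlx.nn as nn

# Assumed constructor arguments; the final API may differ.
up = nn.Upsample(scale_factor=2, mode="nearest")
x = mx.zeros((1, 8, 8, 3))   # (B, H, W, C), channels-last as in MLX
y = up(x)                    # -> (1, 16, 16, 3)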
(force-pushed from 4f0c9e3 to 217905f)
This should be good to review and merge if all ok.
Proposed changes

Implemented two-dimensional upsampling, using nearest neighbor and bilinear interpolation. The implementation was tested by comparison to PyTorch. Closes #73.

My only concern is performance. More precisely, nearest neighbor is implemented using as_strided and a reshape, while bilinear directly implements the interpolation equations in Python. However, there is a reshape from (B, H, W, C) to (B, C, H, W). I struggled to get the indexing right without this. Any suggestion to avoid this is welcome. Apart from this, there are existing upsampling kernels in MPS, namely MPSCNNUpsamplingBilinear and MPSCNNUpsamplingNearest. Should we use these?
Checklist

Put an x in the boxes that apply.

- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes

Testing Script