Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upsample2d #414

Merged
merged 27 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f7adc6f
draft implementation of upsample2d
gboduljak Jan 10, 2024
ea0452c
added tests
gboduljak Jan 10, 2024
4c42e25
docs
gboduljak Jan 13, 2024
485fdfe
added tests for different height and with
gboduljak Jan 13, 2024
3cf5943
tests for _extra_repr
gboduljak Jan 13, 2024
2bced16
added upsample layer to docs
gboduljak Jan 13, 2024
3611099
Update acknowledgements
gboduljak Jan 15, 2024
d9e5283
Merge branch 'main' into upsample-2d
gboduljak Jan 28, 2024
f815683
Refactor Upsample2d
angeloskath Feb 6, 2024
69f0643
Fix bilinear bug and tests
angeloskath Feb 6, 2024
4e586ae
Update python/mlx/nn/layers/upsample.py
gboduljak Feb 9, 2024
adf722f
Update python/mlx/nn/layers/upsample.py
gboduljak Feb 9, 2024
8fb7ddf
Update python/mlx/nn/layers/upsample.py
gboduljak Feb 9, 2024
81a5ae1
improved docs examples readability
gboduljak Feb 9, 2024
57eb900
Merge branch 'upsample-2d' of github.com:gboduljak/mlx into upsample-2d
gboduljak Feb 9, 2024
e548f8f
removed unused import
gboduljak Feb 9, 2024
c0f68e8
Merge branch 'main' into upsample-2d
gboduljak Feb 13, 2024
cea969b
rename to Upsample
gboduljak Feb 14, 2024
2980022
fix docs upsample link
gboduljak Feb 14, 2024
319e2e9
renamed scale to scale_factor
gboduljak Feb 14, 2024
9acad14
Merge branch 'main' into upsample-2d
gboduljak Feb 16, 2024
9ec4556
updated ACKNOWLEDGMENTS.md
gboduljak Feb 16, 2024
04ea7b0
added align_corners
gboduljak Feb 16, 2024
621a84b
Generalize upsample to many dims
angeloskath Feb 17, 2024
8e07a0a
Merge branch 'main' into upsample-2d
angeloskath Feb 19, 2024
4fb92da
Change to linear and update docs
angeloskath Feb 21, 2024
217905f
Fix ACKNOWLEDGMENTS
angeloskath Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method, `InstanceNorm`, and `Upsample2d` layers.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
Expand Down
1 change: 1 addition & 0 deletions docs/src/python/nn/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ Layers
SinusoidalPositionalEncoding
Step
Transformer
Upsample2d
1 change: 1 addition & 0 deletions python/mlx/nn/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@
TransformerEncoder,
TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample2d
145 changes: 145 additions & 0 deletions python/mlx/nn/layers/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright © 2023-2024 Apple Inc.

from typing import List, Literal, Tuple, Union

import mlx.core as mx
from mlx.nn.layers.base import Module


def upsample2d_nearest(x: mx.array, scale: Tuple[float, float]):
# Integer scales means we can simply expand-broadcast and reshape
if tuple(map(int, scale)) == scale:
sh, sw = map(int, scale)
B, H, W, C = x.shape
x = x[:, :, None, :, None]
x = mx.broadcast_to(x, (B, H, sh, W, sw, C))
x = x.reshape(B, H * sh, W * sw, C)
return x

# Floating point scale means we need to do indexing
else:
sh, sw = scale
B, H, W, C = x.shape
new_H = int(H * sh)
new_W = int(W * sw)
idx_y = (mx.arange(0, new_H) / sh).astype(mx.int32)
idx_x = (mx.arange(0, new_W) / sw).astype(mx.int32)
return x[:, idx_y[:, None], idx_x[None]]


def upsample2d_bilinear(x: mx.array, scale: Tuple[float, float]):
sh, sw = scale
B, H, W, C = x.shape
new_H = int(H * sh)
new_W = int(W * sw)
idx_y = mx.arange(0, new_H) * ((H - 1) / (new_H - 1))
idx_x = mx.arange(0, new_W) * ((W - 1) / (new_W - 1))
# Compute the sampling grid
idx_y_t = mx.floor(idx_y).astype(mx.int32)
idx_y_b = mx.ceil(idx_y).astype(mx.int32)
idx_x_l = mx.floor(idx_x).astype(mx.int32)
idx_x_r = mx.ceil(idx_x).astype(mx.int32)
# Sample
a = x[:, idx_y_t[:, None], idx_x_l[None]]
b = x[:, idx_y_t[:, None], idx_x_r[None]]
c = x[:, idx_y_b[:, None], idx_x_l[None]]
d = x[:, idx_y_b[:, None], idx_x_r[None]]
# Compute bilinear interpolation weights
y_weight = (idx_y - idx_y_t)[:, None, None]
x_weight = (idx_x - idx_x_l)[None, :, None]
w_a = (1 - x_weight) * (1 - y_weight)
w_b = x_weight * (1 - y_weight)
w_c = y_weight * (1 - x_weight)
w_d = x_weight * y_weight
# Interpolate
return w_a * a + w_b * b + w_c * c + w_d * d


class Upsample2d(Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice PyTorch has a single Upsample class which handles different dimensions. It might be worth making that consistent and then throwing (or supporting) on the dimensions not yet handled.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awni Thank you for this suggestion. I was thinking of implementing upsampling within nn.Upsample. I had a similar idea/comment on #357. There, I went with nn.Pooling instead of nn.MaxPooling1d or nn.MaxPooling2d and @angeloskath suggested we go for a different class based on the dimension or pooling type. Thus, to be consistent with that review, I implemented nn.Upsample2d. In my opinion, 2D upsampling is also the most common use case.
Could you please share your thoughts on whether we want nn.Upsample2d or nn.Upsample, based on what we might have for pooling?

r"""Upsamples the given spatial data.

The input is assumed to be a 4D tensor where the channels are expected to be last.
Thus, the input shape should be :math:`(N, H, W, C)` where:
- ``N`` is the batch dimension
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels

Parameters:
scale (float or Tuple[float, float]): The multiplier for spatial size.
If a single number is provided, the provided value is the
multiplier for both the height and width. Otherwise, the first
element of the tuple is the height multipler, while the second is
the width multipler.
gboduljak marked this conversation as resolved.
Show resolved Hide resolved
mode (str, optional): The upsampling algorithm: one of ``nearest`` and
``bilinear``. Default: ``nearest``.
gboduljak marked this conversation as resolved.
Show resolved Hide resolved

Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
>>> x
array([[[[1],
[2]],
[[3],
[4]]]], dtype=int32)
>>> n = nn.Upsample2d(scale=2, mode='nearest')
>>> n(x)
array([[[[1],
[1],
[2],
[2]],
[[1],
[1],
[2],
[2]],
[[3],
[3],
[4],
[4]],
[[3],
[3],
[4],
[4]]]], dtype=int32)
gboduljak marked this conversation as resolved.
Show resolved Hide resolved
>>> b = nn.Upsample2d(scale=2, mode='bilinear')
>>> b(x)
array([[[[1],
[1.33333],
[1.66667],
[2]],
[[1.66667],
[2],
[2.33333],
[2.66667]],
[[2.33333],
[2.66667],
[3],
[3.33333]],
[[3],
[3.33333],
[3.66667],
[4]]]], dtype=float32)
"""

def __init__(
self,
scale: Union[float, Tuple[float, float]],
mode: Literal["nearest", "bilinear"] = "nearest",
):
super().__init__()
if mode not in ["nearest", "bilinear"]:
raise ValueError("[upsample2d] unsupported upsampling algorithm")
gboduljak marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(scale, (list, tuple)):
self.scale = tuple(map(float, scale))
else:
self.scale = (float(scale), float(scale))
self.mode = mode

def _extra_repr(self) -> str:
return f"scale={self.scale}, mode={self.mode!r}"

def __call__(self, x: mx.array) -> mx.array:
if self.mode == "bilinear":
return upsample2d_bilinear(x, self.scale)
else:
return upsample2d_nearest(x, self.scale)
185 changes: 185 additions & 0 deletions python/tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,191 @@ def test_dropout3d(self):
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)

def test_upsample2d(self):
b, h, w, c = 1, 2, 2, 1
scale = 2
upsample_nearest = nn.Upsample2d(scale=scale, mode="nearest")
upsample_bilinear = nn.Upsample2d(scale=scale, mode="bilinear")
# Test single feature map
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
expected_nearest = mx.array(
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.333333, 0.666667, 1],
[0.666667, 1, 1.33333, 1.66667],
[1.33333, 1.66667, 2, 2.33333],
[2, 2.33333, 2.66667, 3],
]
]
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test a more complex batch
b, h, w, c = 2, 3, 3, 2
scale = 2
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))

upsample_nearest = nn.Upsample2d(scale=scale, mode="nearest")
upsample_bilinear = nn.Upsample2d(scale=scale, mode="bilinear")

expected_nearest = mx.array(
[
[
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
],
[
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
],
],
[
[
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
],
[
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
],
[
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
],
],
[
[
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
],
[
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

# Test different height and width scale
b, h, w, c = 1, 2, 2, 2
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample2d(scale=(2, 3), mode="nearest")
upsample_bilinear = nn.Upsample2d(scale=(2, 3), mode="bilinear")

expected_nearest = mx.array(
[
[
[
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[2, 2, 2, 3, 3, 3],
[2, 2, 2, 3, 3, 3],
],
[
[4, 4, 4, 5, 5, 5],
[4, 4, 4, 5, 5, 5],
[6, 6, 6, 7, 7, 7],
[6, 6, 6, 7, 7, 7],
],
]
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.2, 0.4, 0.6, 0.8, 1],
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
[2, 2.2, 2.4, 2.6, 2.8, 3],
],
[
[4, 4.2, 4.4, 4.6, 4.8, 5],
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
[6, 6.2, 6.4, 6.6, 6.8, 7],
],
]
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

# Test repr
self.assertEqual(
str(nn.Upsample2d(scale=2)), "Upsample2d(scale=(2.0, 2.0), mode='nearest')"
)
self.assertEqual(
str(nn.Upsample2d(scale=2, mode="nearest")),
"Upsample2d(scale=(2.0, 2.0), mode='nearest')",
)
self.assertEqual(
str(nn.Upsample2d(scale=2, mode="bilinear")),
"Upsample2d(scale=(2.0, 2.0), mode='bilinear')",
)
self.assertEqual(
str(nn.Upsample2d(scale=(2, 3))),
"Upsample2d(scale=(2.0, 3.0), mode='nearest')",
)
self.assertEqual(
str(nn.Upsample2d(scale=(2, 3), mode="nearest")),
"Upsample2d(scale=(2.0, 3.0), mode='nearest')",
)
self.assertEqual(
str(nn.Upsample2d(scale=(2, 3), mode="bilinear")),
"Upsample2d(scale=(2.0, 3.0), mode='bilinear')",
)


if __name__ == "__main__":
unittest.main()