Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upsample2d #414

Merged
merged 27 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f7adc6f
draft implementation of upsample2d
gboduljak Jan 10, 2024
ea0452c
added tests
gboduljak Jan 10, 2024
4c42e25
docs
gboduljak Jan 13, 2024
485fdfe
added tests for different height and with
gboduljak Jan 13, 2024
3cf5943
tests for _extra_repr
gboduljak Jan 13, 2024
2bced16
added upsample layer to docs
gboduljak Jan 13, 2024
3611099
Update acknowledgements
gboduljak Jan 15, 2024
d9e5283
Merge branch 'main' into upsample-2d
gboduljak Jan 28, 2024
f815683
Refactor Upsample2d
angeloskath Feb 6, 2024
69f0643
Fix bilinear bug and tests
angeloskath Feb 6, 2024
4e586ae
Update python/mlx/nn/layers/upsample.py
gboduljak Feb 9, 2024
adf722f
Update python/mlx/nn/layers/upsample.py
gboduljak Feb 9, 2024
8fb7ddf
Update python/mlx/nn/layers/upsample.py
gboduljak Feb 9, 2024
81a5ae1
improved docs examples readability
gboduljak Feb 9, 2024
57eb900
Merge branch 'upsample-2d' of github.com:gboduljak/mlx into upsample-2d
gboduljak Feb 9, 2024
e548f8f
removed unused import
gboduljak Feb 9, 2024
c0f68e8
Merge branch 'main' into upsample-2d
gboduljak Feb 13, 2024
cea969b
rename to Upsample
gboduljak Feb 14, 2024
2980022
fix docs upsample link
gboduljak Feb 14, 2024
319e2e9
renamed scale to scale_factor
gboduljak Feb 14, 2024
9acad14
Merge branch 'main' into upsample-2d
gboduljak Feb 16, 2024
9ec4556
updated ACKNOWLEDGMENTS.md
gboduljak Feb 16, 2024
04ea7b0
added align_corners
gboduljak Feb 16, 2024
621a84b
Generalize upsample to many dims
angeloskath Feb 17, 2024
8e07a0a
Merge branch 'main' into upsample-2d
angeloskath Feb 19, 2024
4fb92da
Change to linear and update docs
angeloskath Feb 21, 2024
217905f
Fix ACKNOWLEDGMENTS
angeloskath Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
Expand Down Expand Up @@ -253,4 +253,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
1 change: 1 addition & 0 deletions docs/src/python/nn/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ Layers
Softshrink
Step
Transformer
Upsample
1 change: 1 addition & 0 deletions python/mlx/nn/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@
TransformerEncoder,
TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample
205 changes: 205 additions & 0 deletions python/mlx/nn/layers/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Copyright © 2023-2024 Apple Inc.

import operator
from functools import reduce
from itertools import product
from typing import Literal, Tuple, Union

import mlx.core as mx
from mlx.nn.layers.base import Module


def _scaled_indices(N, scale, align_corners, dim, ndims):
M = int(scale * N)
if align_corners:
indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
else:
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2
indices = mx.arange(M, dtype=mx.float32) * step - start
indices = mx.clip(indices, 0, N - 1)
shape = [1] * ndims
shape[dim] = -1

return indices.reshape(shape)


def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)


def _linear_indices(N, scale, align_corners, dim, ndims):
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
indices_l = mx.floor(indices)
indices_r = mx.ceil(indices)
weight = indices - indices_l
weight = mx.expand_dims(weight, -1)

return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
)


def upsample_nearest(x: mx.array, scale_factor: Tuple):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")

# Integer scale_factors means we can simply expand-broadcast and reshape
if tuple(map(int, scale_factor)) == scale_factor:
shape = list(x.shape)
for d in range(dims):
shape.insert(2 + 2 * d, 1)
x = x.reshape(shape)
for d in range(dims):
shape[2 + 2 * d] = int(scale_factor[d])
x = mx.broadcast_to(x, shape)
for d in range(dims):
shape[d + 1] *= shape[d + 2]
shape.pop(d + 2)
x = x.reshape(shape)
return x

else:
B, *N, C = x.shape
indices = [slice(None)]
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_nearest_indices(n, s, i, dims))
indices = tuple(indices)

return x[indices]


def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")

B, *N, C = x.shape

# Compute the sampling grid
indices = []
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_linear_indices(n, s, align_corners, i, dims))

# Sample and compute the weights
samples = []
weights = []
for idx_weight in product(*indices):
idx, weight = zip(*idx_weight)
samples.append(x[(slice(None),) + idx])
weights.append(reduce(operator.mul, weight))

# Interpolate
return sum(wi * xi for wi, xi in zip(weights, samples))


class Upsample(Module):
r"""Upsample the input signal spatially.

The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
2``. The first is the batch dimension and the last is the feature
dimension.

For example, an audio signal would be 3D with 1 spatial dimension, an image
4D with 2 and so on and so forth.

There are two upsampling algorithms implemented nearest neighbor upsampling
and linear interpolation. Both can be applied to any number of spatial
dimensions and the linear interpolation will be bilinear, trilinear etc
when applied to more than one spatial dimension.

.. note::
When using one of the linear interpolation modes the ``align_corners``
argument changes how the corners are treated in the input image. If
``align_corners=True`` then the top and left edge of the input and
output will be matching as will the bottom right edge.

Parameters:
scale_factor (float or tuple): The multiplier for the spatial size.
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
Otherwise, the number of scale factors provided must match the
number of spatial dimensions.
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
``"linear"``. Default: ``"nearest"``.
align_corners (bool, optional): Changes the way the corners are treated
during ``"linear"`` upsampling. See the note above and the
examples below for more details. Default: ``False``.

Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
>>> x
array([[[[1],
[2]],
[[3],
[4]]]], dtype=int32)
>>> n = nn.Upsample(scale_factor=2, mode='nearest')
>>> n(x).squeeze()
array([[1, 1, 2, 2],
[1, 1, 2, 2],
[3, 3, 4, 4],
[3, 3, 4, 4]], dtype=int32)
>>> b = nn.Upsample(scale_factor=2, mode='linear')
>>> b(x).squeeze()
array([[1, 1.25, 1.75, 2],
[1.5, 1.75, 2.25, 2.5],
[2.5, 2.75, 3.25, 3.5],
[3, 3.25, 3.75, 4]], dtype=float32)
>>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
>>> b(x).squeeze()
array([[1, 1.33333, 1.66667, 2],
[1.66667, 2, 2.33333, 2.66667],
[2.33333, 2.66667, 3, 3.33333],
[3, 3.33333, 3.66667, 4]], dtype=float32)
"""

def __init__(
self,
scale_factor: Union[float, Tuple],
mode: Literal["nearest", "linear"] = "nearest",
align_corners: bool = False,
):
super().__init__()
if mode not in ["nearest", "linear"]:
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
if isinstance(scale_factor, (list, tuple)):
self.scale_factor = tuple(map(float, scale_factor))
else:
self.scale_factor = float(scale_factor)
self.mode = mode
self.align_corners = align_corners

def _extra_repr(self) -> str:
return (
f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
f"align_corners={self.align_corners}"
)

def __call__(self, x: mx.array) -> mx.array:
dims = x.ndim - 2
if dims <= 0:
raise ValueError(
f"[Upsample] The input should have at least 1 spatial "
f"dimension which means it should be at least 3D but "
f"{x.ndim}D was provided"
)

scale_factor = self.scale_factor
if isinstance(scale_factor, tuple):
if len(scale_factor) != dims:
raise ValueError(
f"[Upsample] One scale per spatial dimension is required but "
f"scale_factor={scale_factor} and the number of spatial "
f"dimensions were {dims}"
)
else:
scale_factor = (scale_factor,) * dims

if self.mode == "nearest":
return upsample_nearest(x, scale_factor)

else:
return upsample_linear(x, scale_factor, self.align_corners)
Loading