Skip to content

Commit

Permalink
Fix bilinear bug and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath committed Feb 6, 2024
1 parent f815683 commit 69f0643
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
17 changes: 9 additions & 8 deletions python/mlx/nn/layers/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlx.nn.layers.base import Module


def upsample2d_nearest(x: mx.array, scale: Union[Tuple[int, int], Tuple[float, float]]):
def upsample2d_nearest(x: mx.array, scale: Tuple[float, float]):
# Integer scales means we can simply expand-broadcast and reshape
if tuple(map(int, scale)) == scale:
sh, sw = map(int, scale)
Expand All @@ -27,9 +27,7 @@ def upsample2d_nearest(x: mx.array, scale: Union[Tuple[int, int], Tuple[float, f
return x[:, idx_y[:, None], idx_x[None]]


def upsample2d_bilinear(
x: mx.array, scale: Union[Tuple[int, int], Tuple[float, float]]
):
def upsample2d_bilinear(x: mx.array, scale: Tuple[float, float]):
sh, sw = scale
B, H, W, C = x.shape
new_H = int(H * sh)
Expand All @@ -47,8 +45,8 @@ def upsample2d_bilinear(
c = x[:, idx_y_b[:, None], idx_x_l[None]]
d = x[:, idx_y_b[:, None], idx_x_r[None]]
# Compute bilinear interpolation weights
y_weight = (idx_y - idx_y_t)[:, None]
x_weight = (idx_x - idx_x_l)[:, None]
y_weight = (idx_y - idx_y_t)[:, None, None]
x_weight = (idx_x - idx_x_l)[None, :, None]
w_a = (1 - x_weight) * (1 - y_weight)
w_b = x_weight * (1 - y_weight)
w_c = y_weight * (1 - x_weight)
Expand Down Expand Up @@ -131,11 +129,14 @@ def __init__(
super().__init__()
if mode not in ["nearest", "bilinear"]:
raise ValueError("[upsample2d] unsupported upsampling algorithm")
self.scale = tuple(map(float, scale))
if isinstance(scale, (list, tuple)):
self.scale = tuple(map(float, scale))
else:
self.scale = (float(scale), float(scale))
self.mode = mode

def _extra_repr(self) -> str:
return f"scale={self.scale}, mode={self.mode}"
return f"scale={self.scale}, mode={self.mode!r}"

def __call__(self, x: mx.array) -> mx.array:
if self.mode == "bilinear":
Expand Down
33 changes: 17 additions & 16 deletions python/tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,27 +1046,28 @@ def test_upsample2d(self):
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

# Test repr
self.assertTrue(
str(nn.Upsample2d(scale=2)) == "Upsample2d(scale=2, mode=nearest)"
self.assertEqual(
str(nn.Upsample2d(scale=2)), "Upsample2d(scale=(2.0, 2.0), mode='nearest')"
)
self.assertTrue(
str(nn.Upsample2d(scale=2, mode="nearest"))
== "Upsample2d(scale=2, mode=nearest)"
self.assertEqual(
str(nn.Upsample2d(scale=2, mode="nearest")),
"Upsample2d(scale=(2.0, 2.0), mode='nearest')",
)
self.assertTrue(
str(nn.Upsample2d(scale=2, mode="bilinear"))
== "Upsample2d(scale=2, mode=bilinear)"
self.assertEqual(
str(nn.Upsample2d(scale=2, mode="bilinear")),
"Upsample2d(scale=(2.0, 2.0), mode='bilinear')",
)
self.assertTrue(
str(nn.Upsample2d(scale=(2, 3))) == "Upsample2d(scale=(2, 3), mode=nearest)"
self.assertEqual(
str(nn.Upsample2d(scale=(2, 3))),
"Upsample2d(scale=(2.0, 3.0), mode='nearest')",
)
self.assertTrue(
str(nn.Upsample2d(scale=(2, 3), mode="nearest"))
== "Upsample2d(scale=(2, 3), mode=nearest)"
self.assertEqual(
str(nn.Upsample2d(scale=(2, 3), mode="nearest")),
"Upsample2d(scale=(2.0, 3.0), mode='nearest')",
)
self.assertTrue(
str(nn.Upsample2d(scale=(2, 3), mode="bilinear"))
== "Upsample2d(scale=(2, 3), mode=bilinear)"
self.assertEqual(
str(nn.Upsample2d(scale=(2, 3), mode="bilinear")),
"Upsample2d(scale=(2.0, 3.0), mode='bilinear')",
)


Expand Down

0 comments on commit 69f0643

Please sign in to comment.