From 81a5ae1d8b2ce79c084ea415608c083ed13beb8e Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Fri, 9 Feb 2024 01:34:34 +0100 Subject: [PATCH] improved docs examples readability --- python/mlx/nn/layers/upsample.py | 44 ++++++++------------------------ 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index cabf9cbd4..165a3ac89 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -84,41 +84,17 @@ class Upsample2d(Module): [[3], [4]]]], dtype=int32) >>> n = nn.Upsample2d(scale=2, mode='nearest') - >>> n(x) - array([[[[1], - [1], - [2], - [2]], - [[1], - [1], - [2], - [2]], - [[3], - [3], - [4], - [4]], - [[3], - [3], - [4], - [4]]]], dtype=int32) + >>> n(x).squeeze() + array([[1, 1, 2, 2], + [1, 1, 2, 2], + [3, 3, 4, 4], + [3, 3, 4, 4]], dtype=int32) >>> b = nn.Upsample2d(scale=2, mode='bilinear') - >>> b(x) - array([[[[1], - [1.33333], - [1.66667], - [2]], - [[1.66667], - [2], - [2.33333], - [2.66667]], - [[2.33333], - [2.66667], - [3], - [3.33333]], - [[3], - [3.33333], - [3.66667], - [4]]]], dtype=float32) + >>> b(x).squeeze() + array([[1, 1.33333, 1.66667, 2], + [1.66667, 2, 2.33333, 2.66667], + [2.33333, 2.66667, 3, 3.33333], + [3, 3.33333, 3.66667, 4]], dtype=float32) """ def __init__(