diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f1531bb88..12cd354a0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,26 +5,26 @@ possible. ## Pull Requests -1. Fork and submit pull requests to the repo. +1. Fork and submit pull requests to the repo. 2. If you've added code that should be tested, add tests. 3. If a change is likely to impact efficiency, run some of the benchmarks before and after the change. Examples of benchmarks can be found in `benchmarks/python/`. 4. If you've changed APIs, update the documentation. -5. Every PR should have passing tests and at least one review. +5. Every PR should have passing tests and at least one review. 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. This should install hooks for running `black` and `clang-format` to ensure consistent style for C++ and python code. - + You can also run the formatters manually as follows: - - ``` - clang-format -i file.cpp - ``` - - ``` - black file.py - ``` - + + ```bash + clang-format -i file.cpp + ``` + + ```bash + black file.py + ``` + or run `pre-commit run --all-files` to check all files in the repo. ## Issues diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index a9d3df22d..5b71cf583 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -80,10 +80,8 @@ def predicate(x): _filter = make_predicate(args.filter, args.negative_filter) if args.mlx_dtypes: - compare_filtered = ( - lambda x: compare_mlx_dtypes( - x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1] - ) + compare_filtered = lambda x: ( + compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]) if _filter(x) else None ) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 01ee6d388..87a95986a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2734,7 +2734,8 @@ inline std::vector conv_out_shape( return out_shape; } -inline void run_conv_checks(const array& in, const array& wt, int n_dim) { +inline void +run_conv_checks(const array& in, const array& wt, int n_dim, int n_groups) { if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) { std::ostringstream msg; msg << "[conv] Invalid input array with type " << in.dtype() << "." @@ -2767,6 +2768,22 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) { << " input: " << in.shape() << " and weight: " << wt.shape(); throw std::invalid_argument(msg.str()); } + + if (in.shape(n_dim + 1) % n_groups != 0) { + std::ostringstream msg; + msg << "[conv] The number of input channels must be divisible by the number" + << " of groups. Got input with shape " << in.shape() << " and groups " + << n_groups << "."; + throw std::invalid_argument(msg.str()); + } + + if (wt.shape(n_dim + 1) % n_groups != 0) { + std::ostringstream msg; + msg << "[conv] The number of output channels must be divisible by the number" + << " of groups. Got weight with shape " << wt.shape() << " and groups " + << n_groups << "."; + throw std::invalid_argument(msg.str()); + } } } // namespace @@ -2781,15 +2798,15 @@ array conv1d( int groups /* = 1 */, StreamOrDevice s /* = {} */) { // Run checks - if (groups != 1) { - throw std::invalid_argument("[conv1d] Cannot handle groups != 1 yet"); + if (groups < 1) { + throw std::invalid_argument("[conv1d] Invalid groups < 1"); } if (dilation != 1) { throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet"); } // Run checks - run_conv_checks(in_, wt_, 1); + run_conv_checks(in_, wt_, 1, groups); auto in = in_; auto wt = wt_; @@ -2802,21 +2819,48 @@ array conv1d( std::vector strides_vec = {stride}; std::vector padding_vec = {padding}; std::vector dilation_vec = {dilation}; + std::vector input_dilation_vec = {1, 1}; - // Get output shapes - std::vector out_shape = conv_out_shape( - in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec); - - return array( - out_shape, - in.dtype(), - std::make_unique( - to_stream(s), - padding_vec, + if (groups == 1) { + // Get output shapes + std::vector out_shape = conv_out_shape( + in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec); + return array( + out_shape, + in.dtype(), + std::make_unique( + to_stream(s), + padding_vec, + strides_vec, + dilation_vec, + input_dilation_vec), + {in, wt}); + } else { + // Grouped convolution + auto in_slices = split(in, groups, -1, s); + auto wt_slices = split(wt, groups, 0, s); + std::vector out_slices; + for (auto i = 0; i < groups; i++) { + auto out_shape = conv_out_shape( + in_slices[i].shape(), + wt_slices[i].shape(), strides_vec, - dilation_vec, - std::vector(1, 1)), - {in, wt}); + padding_vec, + dilation_vec); + auto out_slice = array( + out_shape, + in.dtype(), + std::make_unique( + to_stream(s), + padding_vec, + strides_vec, + dilation_vec, + input_dilation_vec), + {in_slices[i], wt_slices[i]}); + out_slices.push_back(out_slice); + } + return concatenate(out_slices, -1, s); + } } /** 2D convolution with a filter */ @@ -2829,15 +2873,15 @@ array conv2d( int groups /* = 1 */, StreamOrDevice s /* = {} */) { // Run checks - if (groups != 1) { - throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet"); + if (groups < 1) { + throw std::invalid_argument("[conv2d] Invalid groups < 1"); } if (dilation.first != 1 || dilation.second != 1) { throw std::invalid_argument("[conv2d] Cannot handle dilation != 1 yet"); } // Run checks - run_conv_checks(in_, wt_, 2); + run_conv_checks(in_, wt_, 2, groups); auto in = in_; auto wt = wt_; @@ -2850,21 +2894,49 @@ array conv2d( std::vector strides_vec = {stride.first, stride.second}; std::vector padding_vec = {padding.first, padding.second}; std::vector dilation_vec = {dilation.first, dilation.second}; + std::vector input_dilation_vec = {2, 1}; - // Get output shapes - std::vector out_shape = conv_out_shape( - in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec); + if (groups == 1) { + // Get output shapes + std::vector out_shape = conv_out_shape( + in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec); - return array( - out_shape, - in.dtype(), - std::make_unique( - to_stream(s), - padding_vec, + return array( + out_shape, + in.dtype(), + std::make_unique( + to_stream(s), + padding_vec, + strides_vec, + dilation_vec, + input_dilation_vec), + {in, wt}); + } else { + // Grouped convolution + auto in_slices = split(in, groups, -1, s); + auto wt_slices = split(wt, groups, 0, s); + std::vector out_slices; + for (auto i = 0; i < groups; i++) { + auto out_shape = conv_out_shape( + in_slices[i].shape(), + wt_slices[i].shape(), strides_vec, - dilation_vec, - std::vector(2, 1)), - {in, wt}); + padding_vec, + dilation_vec); + auto out_slice = array( + out_shape, + in.dtype(), + std::make_unique( + to_stream(s), + padding_vec, + strides_vec, + dilation_vec, + input_dilation_vec), + {in_slices[i], wt_slices[i]}); + out_slices.push_back(out_slice); + } + return concatenate(out_slices, -1, s); + } } array quantized_matmul( diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py index c6928e188..3b3679825 100644 --- a/python/mlx/nn/layers/convolution.py +++ b/python/mlx/nn/layers/convolution.py @@ -23,6 +23,9 @@ class Conv1d(Module): Default: 1. padding (int, optional): How many positions to 0-pad the input with. Default: 0. + dilation (int, optional): The size of the dilation. Default: 1. + groups (int, optional): The number of groups to split the input. + Default: 1. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -34,6 +37,8 @@ def __init__( kernel_size: int, stride: int = 1, padding: int = 0, + dilation: int = 1, + groups: int = 1, bias: bool = True, ): super().__init__() @@ -49,16 +54,21 @@ def __init__( self.padding = padding self.stride = stride + self.dilation = dilation + self.groups = groups def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " - f"padding={self.padding}, bias={'bias' in self}" + f"padding={self.padding}, dilation={self.dilation}, " + f"groups={self.groups} bias={'bias' in self}" ) def __call__(self, x): - y = mx.conv1d(x, self.weight, self.stride, self.padding) + y = mx.conv1d( + x, self.weight, self.stride, self.padding, self.dilation, self.groups + ) if "bias" in self: y = y + self.bias return y @@ -81,6 +91,10 @@ class Conv2d(Module): applying the filter. Default: 1. padding (int or tuple, optional): How many positions to 0-pad the input with. Default: 0. + dilation (int or tuple, optional): The size of the dilation. + Default: 1. + groups (int, optional): The number of groups to split the input. + Default: 1. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -92,13 +106,15 @@ def __init__( kernel_size: Union[int, tuple], stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, + dilation: Union[int, tuple] = 1, + groups: int = 1, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, dilation = map( lambda x: (x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, dilation), ) scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1])) self.weight = mx.random.uniform( @@ -111,16 +127,21 @@ def __init__( self.padding = padding self.stride = stride + self.dilation = dilation + self.groups = groups def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " - f"padding={self.padding}, bias={'bias' in self}" + f"padding={self.padding}, dilation={self.dilation}, " + f"groups={self.groups}, bias={'bias' in self}" ) def __call__(self, x): - y = mx.conv2d(x, self.weight, self.stride, self.padding) + y = mx.conv2d( + x, self.weight, self.stride, self.padding, self.dilation, self.groups + ) if "bias" in self: y = y + self.bias return y diff --git a/python/tests/test_load.py b/python/tests/test_load.py index a37ba83a9..fdd42351d 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -75,9 +75,11 @@ def test_save_and_load_safetensors(self): self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } with open(save_file_mlx, "wb") as f: @@ -104,9 +106,11 @@ def test_save_and_load_gguf(self): self.test_dir, f"mlx_{dt}_{i}_fs.gguf" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } mx.save_gguf(save_file_mlx, save_dict) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7749e159a..af492bbc8 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -627,6 +627,11 @@ def test_conv2d(self): self.assertEqual(y.shape, (4, 3, 3, 8)) self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) + # 3x3 conv with groups = 3 + c = nn.Conv2d(3, 6, 3, groups=3) + y = c(x) + self.assertEqual(y.shape, (4, 6, 6, 6)) + def test_sequential(self): x = mx.ones((10, 2)) m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))