
MLX CTC


C++ and Metal extensions for MLX CTC Loss

Library status

The library passes initial tests and benchmarks.

However, it is still under development and should be considered alpha-quality.

Installation

MLX-CTC is available on PyPI. To install, run:

pip install mlx-ctc

To install the latest version from GitHub, run:

pip install git+https://github.com/djphoenix/mlx-ctc.git@main
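After installing, a quick sanity check is to confirm that the loss function (used in the example below) imports cleanly:

python -c "from mlx_ctc import ctc_loss; print('ok')"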

Usage

The Python API of MLX CTC Loss is designed to closely mimic the PyTorch version (torch.nn.functional.ctc_loss).

Minimal usage example for MLX:

import mlx.core as mx
import mlx.nn as nn
from mlx_ctc import ctc_loss

# Targets are padded to the length of the longest target in the batch
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes

# Initialize a random batch of input log-probabilities of shape (T, N, C)
input = nn.log_softmax(mx.random.normal((T, N, C)), axis=2)

# Initialize random batch of targets (0 = blank, 1:C = classes)
target = mx.random.randint(1, C, shape=(N, S), dtype=mx.uint32)
input_lengths = mx.full((N,), T, dtype=mx.uint32)
target_lengths = mx.random.randint(S_min, S, shape=(N,), dtype=mx.uint32)

# Build a loss function that normalizes by target length and averages over
# the batch (equivalent to PyTorch's reduction='mean')
def ctc_loss_mean(i, t, il, tl):
  return (ctc_loss(i, t, il, tl) / tl).mean()
ctc_loss_grad_fn = mx.value_and_grad(ctc_loss_mean)

# Compute loss and gradient in a single call
loss, grad = ctc_loss_grad_fn(input, target, input_lengths, target_lengths)
mx.eval(loss, grad)

print('Loss:', loss.item())
print('Gradient shape:', grad.shape)
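Because the API mirrors PyTorch, the same computation can be written with torch.nn.functional.ctc_loss for comparison. This is a sketch assuming PyTorch is installed; shapes match the MLX example above:

import torch
import torch.nn.functional as F

T, C, N, S, S_min = 50, 20, 16, 30, 10  # same shapes as the MLX example

log_probs = torch.randn(T, N, C).log_softmax(2).requires_grad_()
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long)

# reduction='mean' divides each loss by its target length and averages over
# the batch, which is what ctc_loss_mean does manually in the MLX example
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss.backward()
print('Loss:', loss.item())
print('Gradient shape:', tuple(log_probs.grad.shape))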

Benchmarks

To run the benchmark on your machine, use:

python tests/benchmark.py

It outputs a table of MB/s rates for CPU and GPU runs, compared against the PyTorch CPU rate.

Example output (MNWA3T/A Mac14,6):

---------------------------------------------------------------------
| Shape (TxBxCxS)       | Torch MB/s | MLX CPU MB/s | MLX GPU MB/s  |
---------------------------------------------------------------------
|   64 x 128 x 32 x  16 |     174.04 |  241.17 1.4x |  219.95  1.3x |
|  128 x 128 x 32 x  32 |      96.62 |  129.52 1.3x |  702.22  7.3x |
|  256 x 128 x 32 x  64 |      52.67 |   68.18 1.3x |  622.79 11.8x |
|  512 x 128 x 32 x 128 |      25.62 |   34.33 1.3x |  657.06 25.6x |
| 1024 x 128 x 32 x 256 |      13.46 |   18.55 1.4x |  573.43 42.6x |
---------------------------------------------------------------------
|  128 x  32 x 32 x  32 |      99.31 |  139.19 1.4x |  212.03  2.1x |
|  128 x  64 x 32 x  32 |      94.13 |  128.56 1.4x |  421.05  4.5x |
|  128 x 128 x 32 x  32 |      97.42 |  130.74 1.3x |  733.86  7.5x |
|  128 x 256 x 32 x  32 |      94.30 |  125.29 1.3x | 1056.41 11.2x |
|  128 x 512 x 32 x  32 |      95.34 |  123.40 1.3x | 1390.39 14.6x |
---------------------------------------------------------------------
|  128 x 128 x  8 x  32 |      26.00 |   36.64 1.4x |  178.84  6.9x |
|  128 x 128 x 16 x  32 |      52.30 |   68.91 1.3x |  367.45  7.0x |
|  128 x 128 x 32 x  32 |      97.42 |  130.53 1.3x |  720.66  7.4x |
|  128 x 128 x 48 x  32 |     134.16 |  183.62 1.4x |  998.70  7.4x |
|  128 x 128 x 64 x  32 |     168.65 |  231.30 1.4x | 1366.02  8.1x |
---------------------------------------------------------------------
|  256 x 128 x 32 x  16 |     168.38 |  208.64 1.2x |  923.31  5.5x |
|  256 x 128 x 32 x  24 |     116.98 |  144.75 1.2x |  952.47  8.1x |
|  256 x 128 x 32 x  32 |      93.67 |  115.91 1.2x |  914.53  9.8x |
|  256 x 128 x 32 x  48 |      66.15 |   81.21 1.2x |  675.61 10.2x |
|  256 x 128 x 32 x  64 |      53.71 |   70.10 1.3x |  590.71 11.0x |
---------------------------------------------------------------------

Please note that MLX does not support multithreading, so torch.set_num_threads(1) is used to keep the comparison fair.
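For illustration, here is a minimal sketch of how a throughput figure of this kind can be computed, assuming the MB/s rate is measured against the size of the float32 input tensor; the actual accounting lives in tests/benchmark.py and may differ:

import time
import mlx.core as mx
import mlx.nn as nn
from mlx_ctc import ctc_loss

# One benchmark shape from the table above: T x B x C x S = 128 x 128 x 32 x 32
T, B, C, S = 128, 128, 32, 32
input = nn.log_softmax(mx.random.normal((T, B, C)), axis=2)
target = mx.random.randint(1, C, shape=(B, S), dtype=mx.uint32)
input_lengths = mx.full((B,), T, dtype=mx.uint32)
target_lengths = mx.full((B,), S, dtype=mx.uint32)
mx.eval(input, target)  # materialize inputs before timing

start = time.perf_counter()
mx.eval(ctc_loss(input, target, input_lengths, target_lengths))
elapsed = time.perf_counter() - start

mb = T * B * C * 4 / 1e6  # float32 input size in MB (assumed metric)
print(f'{mb / elapsed:.2f} MB/s')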

TODO

  • Further optimize the code

Credits

Thanks to the PyTorch source for the reference implementation, which was used for the initial CTC Loss development.
