Skip to content

Commit

Permalink
Add MLX inference support (#41)
Browse files Browse the repository at this point in the history
*Issue #, if available:* #28

*Description of changes:* This PR adds MLX inference support.

## Summary of changes
- Update `pyproject.toml` with`mlx` dependencies.
- Create `chronos_mlx` package which will hosts all mlx inference stuff.
- All classes from `main:src/chronos/chronos.py` are copy-pasted into
`mlx:src/chronos_mlx/chronos.py` and modified to use numpy and mlx
arrays instead. Note that the reason for using numpy arrays as input and
output is that mlx doesn't support some operations that are required for
input and output transform.
- MLX implementation of T5 is in `src/chronos_mlx/t5.py`. It has been
adapted from
[ml-explore/mlx-examples](https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/t5.py)
with the following main modifications:
      - Added support for attention mask.
      - Added support for top_k and top_p sampling.
- `src/chronos_mlx/translate.py` translates weights from a torch HF
model to mlx.
- Add `THIRD-PARTY-LICENSES.txt` for third party code from
`mlx-examples`.
- Add tests and CI for `mlx` version. 

## Sample inference code

```py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
    context, prediction_length
)  # shape [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()

```

## Benchmark


![benchmark](https://github.com/amazon-science/chronos-forecasting/assets/4028948/ee5d1b17-d33e-473c-aa7a-55dbe1059b9c)


```py
import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX


def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [
            torch.tensor(item["target"])
            for item in test_data_input[idx : idx + batch_size]
        ]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = torch.cat(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [item["target"] for item in test_data_input[idx : idx + batch_size]]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = np.concatenate(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst, start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    if version == "cpu" or version == "mps":
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
        benchmark_fn = benchmark_torch_model
    else:
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
        benchmark_fn = benchmark_mlx_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df


if __name__ == "__main__":
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8
    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
                df["version"] = version
                df["dtype"] = dtype
                dfs.append(df)
            except TypeError:
                pass

    result_df = pd.concat(dfs).reset_index(drop=True)
    result_df.to_csv("benchmark.csv", index=False)

    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )
    fig = plt.figure(figsize=(8, 5))
    g = sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference Time (on M1 Pro)")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)

```

## TODOs:
- [x] Implement `top_p` sampling.
- [x] Add tests.
- [x] Add CI.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
  • Loading branch information
abdulfatir and Abdul Fatir Ansari authored Apr 8, 2024
1 parent 2042779 commit 159ea36
Show file tree
Hide file tree
Showing 10 changed files with 830 additions and 398 deletions.
45 changes: 12 additions & 33 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,23 @@ name: CI
on: [push, pull_request]

jobs:
type-check:
test-mlx:
strategy:
max-parallel: 4
fail-fast: false
matrix:
python-version: ["3.11"]
platform: [ubuntu-latest]
python-version: ['3.11']
platform: [macos-14]

runs-on: ${{ matrix.platform }}

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[typecheck]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Type checks with mypy
run: mypy src test

test:
strategy:
max-parallel: 4
fail-fast: false
matrix:
python-version: ["3.11"]
platform: [ubuntu-latest]

runs-on: ${{ matrix.platform }}

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Test with pytest
run: pytest
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Test with pytest
run: pytest test/
27 changes: 13 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Chronos: Learning the Language of Time Series
# [🧪 MLX Version] Chronos: Learning the Language of Time Series

> [!IMPORTANT]
> This is the **experimental** MLX version of Chronos for Apple Silicon Macs. Please use the `main` branch for the stable PyTorch version.
Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.

Expand Down Expand Up @@ -28,10 +31,10 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o

## Usage

To perform inference with Chronos models, install this package by running:
To perform inference with Chronos models on Apple Silcon devices, install this package by running:

```
pip install git+https://github.com/amazon-science/chronos-forecasting.git
pip install git+https://github.com/amazon-science/chronos-forecasting.git@mlx
```

### Forecasting
Expand All @@ -43,20 +46,18 @@ A minimal example showing how to perform forecasting using Chronos models:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
device_map="cuda",
torch_dtype=torch.bfloat16,
dtype="bfloat16",
)

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
context,
Expand All @@ -69,7 +70,7 @@ forecast = pipeline.predict(

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
Expand All @@ -86,20 +87,18 @@ A minimal example showing how to extract encoder embeddings from Chronos models:

```python
import pandas as pd
import torch
from chronos import ChronosPipeline
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
device_map="cuda",
torch_dtype=torch.bfloat16,
dtype="bfloat16",
)

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
context = df["#Passengers"].values
embeddings, tokenizer_state = pipeline.embed(context)
```

Expand Down
23 changes: 23 additions & 0 deletions THIRD-PARTY-LICENSES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
** mlx-examples; version b8a348c -- https://github.com/ml-explore/mlx-examples

MIT License

Copyright © 2023 Apple Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[project]
name = "chronos"
name = "chronos_mlx"
version = "1.1.0"
requires-python = ">=3.8"
license = { file = "LICENSE" }
dependencies = [
"torch~=2.1", # package was tested on 2.2
"transformers~=4.31",
"accelerate",
"mlx~=0.9.0"
]

[project.optional-dependencies]
Expand Down
File renamed without changes.
Loading

0 comments on commit 159ea36

Please sign in to comment.