*Issue #, if available:* #28

*Description of changes:* This PR adds MLX inference support.

## Summary of changes

- Update `pyproject.toml` with `mlx` dependencies.
- Create the `chronos_mlx` package, which will host all MLX inference code.
- All classes from `main:src/chronos/chronos.py` are copy-pasted into `mlx:src/chronos_mlx/chronos.py` and modified to use numpy and mlx arrays instead. Numpy arrays are used for input and output because mlx does not support some operations required by the input and output transforms.
- The MLX implementation of T5 is in `src/chronos_mlx/t5.py`. It has been adapted from [ml-explore/mlx-examples](https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/t5.py) with the following main modifications:
    - Added support for attention mask.
    - Added support for top_k and top_p sampling (see the sampling sketch after this list).
- `src/chronos_mlx/translate.py` translates weights from a torch HF model to mlx (see the translation sketch after this list).
- Add `THIRD-PARTY-LICENSES.txt` for third-party code from `mlx-examples`.
- Add tests and CI for the `mlx` version.
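For readers unfamiliar with nucleus sampling, the sketch below shows the general idea in plain NumPy. It is only an illustration: the actual implementation in `src/chronos_mlx/t5.py` operates on batched MLX logit arrays, and the function name here is hypothetical.

```py
import numpy as np


def sample_top_p(logits: np.ndarray, top_p: float = 0.9, temperature: float = 1.0) -> int:
    """Illustrative nucleus (top-p) sampling over a 1D logits vector (not the PR's API)."""
    # Temperature-scaled softmax; subtract the max for numerical stability
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    # Sort tokens by descending probability and find the nucleus cutoff:
    # the smallest prefix whose cumulative probability reaches top_p
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1  # always keep at least one token
    nucleus = order[:cutoff]
    # Renormalize within the nucleus and sample a token id
    nucleus_probs = probs[nucleus] / probs[nucleus].sum()
    return int(np.random.choice(nucleus, p=nucleus_probs))
```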
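The translation step follows the usual torch-to-MLX recipe: read the Hugging Face checkpoint and convert each tensor into an MLX array. Below is a minimal sketch of that idea, assuming the `transformers` package is installed; the function name is hypothetical, and the real `translate.py` likely also remaps parameter names onto the MLX module structure, which is omitted here.

```py
import mlx.core as mx
from transformers import T5ForConditionalGeneration


def torch_to_mlx_weights(model_name: str) -> dict:
    """Convert a torch HF T5 checkpoint's tensors to MLX arrays (sketch only)."""
    torch_model = T5ForConditionalGeneration.from_pretrained(model_name)
    weights = {}
    for name, tensor in torch_model.state_dict().items():
        # .float() upcasts dtypes that numpy cannot represent directly (e.g. bfloat16)
        weights[name] = mx.array(tensor.float().numpy())
    return weights
```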
## Sample inference code

```py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D numpy array, a list of 1D arrays,
# or a left-padded 2D array with batch as the first dimension
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
    context, prediction_length
)  # shape [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()
```

## Benchmark

![benchmark](https://github.com/amazon-science/chronos-forecasting/assets/4028948/ee5d1b17-d33e-473c-aa7a-55dbe1059b9c)

```py
import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX


def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [
            torch.tensor(item["target"])
            for item in test_data_input[idx : idx + batch_size]
        ]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = torch.cat(forecasts)
    end_time = timeit.default_timer()
    print(f"Inference time: {end_time - start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [item["target"] for item in test_data_input[idx : idx + batch_size]]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = np.concatenate(forecasts)
    end_time = timeit.default_timer()
    print(f"Inference time: {end_time - start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst, start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    if version == "cpu" or version == "mps":
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
        benchmark_fn = benchmark_torch_model
    else:
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
        benchmark_fn = benchmark_mlx_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df


if __name__ == "__main__":
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8

    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
                df["version"] = version
                df["dtype"] = dtype
                dfs.append(df)
            except TypeError:
                # skip version/dtype combinations the backend does not support
                pass

    result_df = pd.concat(dfs).reset_index(drop=True)
    result_df.to_csv("benchmark.csv", index=False)
    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )

    fig = plt.figure(figsize=(8, 5))
    g = sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference time in seconds (on M1 Pro)")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)
```

## TODOs:

- [x] Implement `top_p` sampling.
- [x] Add tests.
- [x] Add CI.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>