diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 29693c5..865b970 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,44 +3,23 @@ name: CI on: [push, pull_request] jobs: - type-check: + test-mlx: strategy: max-parallel: 4 fail-fast: false matrix: - python-version: ["3.11"] - platform: [ubuntu-latest] + python-version: ['3.11'] + platform: [macos-14] runs-on: ${{ matrix.platform }} steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: pip install ".[typecheck]" -f https://download.pytorch.org/whl/cpu/torch_stable.html - - name: Type checks with mypy - run: mypy src test - - test: - strategy: - max-parallel: 4 - fail-fast: false - matrix: - python-version: ["3.11"] - platform: [ubuntu-latest] - - runs-on: ${{ matrix.platform }} - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html - - name: Test with pytest - run: pytest + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html + - name: Test with pytest + run: pytest test/ diff --git a/README.md b/README.md index cae2891..c457967 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ -# Chronos: Learning the Language of Time Series +# [🧪 MLX Version] Chronos: Learning the Language of Time Series + +> [!IMPORTANT] +> This is the **experimental** MLX version of Chronos for Apple Silicon Macs. Please use the `main` branch for the stable PyTorch version. Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes. 
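The paragraph above describes the scaling-and-quantization step only in prose. As a reading aid, here is a minimal NumPy sketch of mean scaling followed by uniform-bin quantization, in the spirit of the `MeanScaleUniformBins` tokenizer changed later in this diff; the bin limits, vocabulary size, and toy series are illustrative assumptions, not the values shipped with the pretrained models.

```python
import numpy as np

# Illustrative values only: the real limits and vocabulary size come from the
# model's chronos_config, not from these numbers.
low_limit, high_limit, n_bins = -15.0, 15.0, 4094

context = np.array([112.0, 118.0, 132.0, 129.0, 121.0])  # toy time series

# Mean absolute scaling (as in MeanScaleUniformBins.input_transform)
scale = np.abs(context).mean()
scaled = context / scale

# Uniform bin centers and boundaries, then quantize with np.digitize
centers = np.linspace(low_limit, high_limit, n_bins)
boundaries = np.concatenate(([-1e20], (centers[1:] + centers[:-1]) / 2, [1e20]))
token_ids = np.digitize(scaled, bins=boundaries)

# Dequantize: map token ids back to bin centers and undo the scaling
reconstructed = centers[np.clip(token_ids - 1, 0, n_bins - 1)] * scale
```

The real tokenizer additionally reserves special tokens (padding and EOS) and handles missing values, as shown in the `MeanScaleUniformBins` changes further down in this diff.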
@@ -28,10 +31,10 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o ## Usage -To perform inference with Chronos models, install this package by running: +To perform inference with Chronos models on Apple Silicon devices, install this package by running: ``` -pip install git+https://github.com/amazon-science/chronos-forecasting.git +pip install git+https://github.com/amazon-science/chronos-forecasting.git@mlx ``` ### Forecasting @@ -43,20 +46,18 @@ A minimal example showing how to perform forecasting using Chronos models: import matplotlib.pyplot as plt import numpy as np import pandas as pd -import torch -from chronos import ChronosPipeline +from chronos_mlx import ChronosPipeline pipeline = ChronosPipeline.from_pretrained( "amazon/chronos-t5-small", - device_map="cuda", - torch_dtype=torch.bfloat16, + dtype="bfloat16", ) df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv") # context must be either a 1D tensor, a list of 1D tensors, # or a left-padded 2D tensor with batch as the first dimension -context = torch.tensor(df["#Passengers"]) +context = df["#Passengers"].values prediction_length = 12 forecast = pipeline.predict( context, @@ -69,7 +70,7 @@ forecast = pipeline.predict( # visualize the forecast forecast_index = range(len(df), len(df) + prediction_length) -low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0) +low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0) plt.figure(figsize=(8, 4)) plt.plot(df["#Passengers"], color="royalblue", label="historical data") @@ -86,20 +87,18 @@ A minimal example showing how to extract encoder embeddings from Chronos models: ```python import pandas as pd -import torch -from chronos import ChronosPipeline +from chronos_mlx import ChronosPipeline pipeline = ChronosPipeline.from_pretrained( "amazon/chronos-t5-small", - device_map="cuda", - torch_dtype=torch.bfloat16, + dtype="bfloat16", ) df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv") # context must be either a 1D tensor, a list of 1D tensors, # or a left-padded 2D tensor with batch as the first dimension -context = torch.tensor(df["#Passengers"]) +context = df["#Passengers"].values embeddings, tokenizer_state = pipeline.embed(context) ``` diff --git a/THIRD-PARTY-LICENSES.txt b/THIRD-PARTY-LICENSES.txt new file mode 100644 index 0000000..1afd20e --- /dev/null +++ b/THIRD-PARTY-LICENSES.txt @@ -0,0 +1,23 @@ +** mlx-examples; version b8a348c -- https://github.com/ml-explore/mlx-examples + +MIT License + +Copyright © 2023 Apple Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 61cafa7..f6de45c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "chronos" +name = "chronos_mlx" version = "1.1.0" requires-python = ">=3.8" license = { file = "LICENSE" } @@ -7,6 +7,7 @@ dependencies = [ "torch~=2.1", # package was tested on 2.2 "transformers~=4.31", "accelerate", + "mlx~=0.9.0" ] [project.optional-dependencies] diff --git a/src/chronos/__init__.py b/src/chronos_mlx/__init__.py similarity index 100% rename from src/chronos/__init__.py rename to src/chronos_mlx/__init__.py diff --git a/src/chronos/chronos.py b/src/chronos_mlx/chronos.py similarity index 62% rename from src/chronos/chronos.py rename to src/chronos_mlx/chronos.py index cc1d3fb..26f4cc7 100644 --- a/src/chronos/chronos.py +++ b/src/chronos_mlx/chronos.py @@ -3,18 +3,18 @@ import warnings from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union -import chronos -import torch -import torch.nn as nn -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForSeq2SeqLM, - GenerationConfig, - PreTrainedModel, -) +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_map, tree_unflatten +from transformers import T5Config + +import chronos_mlx +from chronos_mlx.t5 import T5 +from chronos_mlx.translate import translate_weights @dataclass @@ -46,13 +46,13 @@ def __post_init__(self): ), f"Special token id's must be smaller than {self.n_special_tokens=}" def create_tokenizer(self) -> "ChronosTokenizer": - class_ = getattr(chronos, self.tokenizer_class) + class_ = getattr(chronos_mlx, self.tokenizer_class) return class_(**self.tokenizer_kwargs, config=self) class ChronosTokenizer: """ - A ``ChronosTokenizer`` definines how time series are mapped into token IDs + A ``ChronosTokenizer`` defines how time series are mapped into token IDs and back. For details, see the ``input_transform`` and ``output_transform`` methods, @@ -60,27 +60,27 @@ class ChronosTokenizer: """ def input_transform( - self, context: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + self, context: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, Any]: """ Turn a batch of time series into token IDs, attention map, and scale. Parameters ---------- context - A tensor shaped (batch_size, time_length), containing the - timeseries to forecast. Use left-padding with ``torch.nan`` + A numpy array shaped (batch_size, time_length), containing the + timeseries to forecast. Use left-padding with ``np.nan`` to align time series of different lengths. Returns ------- token_ids - A tensor of integers, shaped (batch_size, time_length + 1) + A numpy array of integers, shaped (batch_size, time_length + 1) if ``config.use_eos_token`` and (batch_size, time_length) otherwise, containing token IDs for the input series. attention_mask - A boolean tensor, same shape as ``token_ids``, indicating - which input observations are not ``torch.nan`` (i.e. not + A boolean numpy array, same shape as ``token_ids``, indicating + which input observations are not ``np.nan`` (i.e. not missing nor padding). tokenizer_state An object that will be passed to ``output_transform``. 
@@ -89,16 +89,14 @@ def input_transform( """ raise NotImplementedError() - def output_transform( - self, samples: torch.Tensor, tokenizer_state: Any - ) -> torch.Tensor: + def output_transform(self, samples: np.ndarray, tokenizer_state: Any) -> np.ndarray: """ Turn a batch of sample token IDs into real values. Parameters ---------- samples - A tensor of integers, shaped (batch_size, num_samples, time_length), + A numpy array of integers, shaped (batch_size, num_samples, time_length), containing token IDs of sample trajectories. tokenizer_state An object returned by ``input_transform`` containing @@ -108,7 +106,7 @@ def output_transform( Returns ------- forecasts - A real tensor, shaped (batch_size, num_samples, time_length), + A real numpy array, shaped (batch_size, num_samples, time_length), containing forecasted sample paths. """ raise NotImplementedError() @@ -119,70 +117,60 @@ def __init__( self, low_limit: float, high_limit: float, config: ChronosConfig ) -> None: self.config = config - self.centers = torch.linspace( + self.centers = np.linspace( low_limit, high_limit, config.n_tokens - config.n_special_tokens - 1, ) - self.boundaries = torch.concat( + self.boundaries = np.concatenate( ( - torch.tensor([-1e20], device=self.centers.device), + np.array([-1e20]), (self.centers[1:] + self.centers[:-1]) / 2, - torch.tensor([1e20], device=self.centers.device), + np.array([1e20]), ) ) def input_transform( - self, context: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, context: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: batch_size, length = context.shape if length > self.config.context_length: context = context[..., -self.config.context_length :] - attention_mask = ~torch.isnan(context) - scale = torch.nansum( - torch.abs(context) * attention_mask, dim=-1 - ) / torch.nansum(attention_mask, dim=-1) + attention_mask = ~np.isnan(context) + scale = np.nansum(np.abs(context) * attention_mask, axis=-1) / np.nansum( + attention_mask, axis=-1 + ) scale[~(scale > 0)] = 1.0 - scaled_context = context / scale.unsqueeze(dim=-1) + scaled_context = context / scale[..., np.newaxis] token_ids = ( - torch.bucketize( - input=scaled_context, - boundaries=self.boundaries, - # buckets are open to the right, see: - # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize - right=True, - ) + np.digitize(scaled_context, bins=self.boundaries) + self.config.n_special_tokens ) token_ids[~attention_mask] = self.config.pad_token_id if self.config.use_eos_token: - eos_tokens = torch.full( - (batch_size, 1), fill_value=self.config.eos_token_id - ) - token_ids = torch.concat((token_ids, eos_tokens), dim=1) - eos_mask = torch.full((batch_size, 1), fill_value=True) - attention_mask = torch.concat((attention_mask, eos_mask), dim=1) + eos_tokens = np.full((batch_size, 1), fill_value=self.config.eos_token_id) + token_ids = np.concatenate((token_ids, eos_tokens), axis=1) + eos_mask = np.full((batch_size, 1), fill_value=True) + attention_mask = np.concatenate((attention_mask, eos_mask), axis=1) return token_ids, attention_mask, scale - def output_transform( - self, samples: torch.Tensor, scale: torch.Tensor - ) -> torch.Tensor: - scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1) - indices = torch.clamp( + def output_transform(self, samples: np.ndarray, scale: np.ndarray) -> np.ndarray: + scale_unsqueezed = scale[..., np.newaxis, np.newaxis] + indices = np.clip( samples - self.config.n_special_tokens, - min=0, - max=len(self.centers) - 1, + a_min=0, + 
a_max=len(self.centers) - 1, ) return self.centers[indices] * scale_unsqueezed class ChronosModel(nn.Module): """ - A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers`` + A ``ChronosModel`` wraps a ``T5`` object from ``chronos_mlx.t5`` and uses it to predict sample paths for time series tokens. Parameters @@ -190,22 +178,21 @@ class ChronosModel(nn.Module): config The configuration to use. model - The pre-trained model to use. + The pretrained T5 model to use. """ - def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None: + def __init__(self, config: ChronosConfig, model: T5) -> None: super().__init__() + assert config.model_type == "seq2seq" and isinstance( + model, T5 + ), "Only the T5 model is currently supported in MLX" self.config = config self.model = model - @property - def device(self): - return self.model.device - def encode( self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, + input_ids: np.ndarray, + attention_mask: np.ndarray, ): """ Extract the encoder embedding for the given token sequences. @@ -213,35 +200,33 @@ def encode( Parameters ---------- input_ids - Tensor of indices of input sequence tokens in the vocabulary + Array of indices of input sequence tokens in the vocabulary with shape (batch_size, sequence_length). attention_mask - A mask tensor of the same shape as input_ids to avoid attending + A mask array of the same shape as input_ids to avoid attending on padding or missing tokens. Returns ------- embedding - A tensor of encoder embeddings with shape + An array of encoder embeddings with shape (batch_size, sequence_length, d_model). """ assert ( self.config.model_type == "seq2seq" ), "Encoder embeddings are only supported for encoder-decoder models" - return self.model.encoder( - input_ids=input_ids, attention_mask=attention_mask - ).last_hidden_state + return self.model.encode(inputs=input_ids, mask=attention_mask) - def forward( + def __call__( self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, + input_ids: mx.array, + attention_mask: mx.array, prediction_length: Optional[int] = None, num_samples: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - ) -> torch.Tensor: + ) -> mx.array: """ Predict future sample tokens for the given token sequences. @@ -253,7 +238,7 @@ def forward( Returns ------- samples - A tensor of integers, shaped (batch_size, num_samples, time_length), + An mx.array of integers, shaped (batch_size, num_samples, time_length), containing forecasted sample paths.
""" if prediction_length is None: @@ -270,40 +255,31 @@ def forward( preds = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, - generation_config=GenerationConfig( - min_new_tokens=prediction_length, - max_new_tokens=prediction_length, - do_sample=True, - num_return_sequences=num_samples, - eos_token_id=self.config.eos_token_id, - pad_token_id=self.config.pad_token_id, - temperature=temperature, - top_k=top_k, - top_p=top_p, - ), + min_new_tokens=prediction_length, + max_new_tokens=prediction_length, + do_sample=True, + num_return_sequences=num_samples, + eos_token_id=self.config.eos_token_id, + pad_token_id=self.config.pad_token_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, ) - if self.config.model_type == "seq2seq": - preds = preds[..., 1:] # remove the decoder start token - else: - assert self.config.model_type == "causal" - assert preds.size(-1) == input_ids.size(-1) + prediction_length - preds = preds[..., -prediction_length:] + preds = preds[..., 1:] # remove the decoder start token - return preds.reshape(input_ids.size(0), num_samples, -1) + return preds.reshape(input_ids.shape[0], num_samples, -1) -def left_pad_and_stack_1D(tensors: List[torch.Tensor]): - max_len = max(len(c) for c in tensors) +def left_pad_and_stack_1D(arrays: List[np.ndarray]): + max_len = max(len(c) for c in arrays) padded = [] - for c in tensors: - assert isinstance(c, torch.Tensor) + for c in arrays: + assert isinstance(c, np.ndarray) assert c.ndim == 1 - padding = torch.full( - size=(max_len - len(c),), fill_value=torch.nan, device=c.device - ) - padded.append(torch.concat((padding, c), dim=-1)) - return torch.stack(padded) + padding = np.full(shape=(max_len - len(c),), fill_value=np.nan) + padded.append(np.concatenate((padding, c), axis=-1)) + return np.stack(padded) class ChronosPipeline: @@ -330,21 +306,20 @@ def __init__(self, tokenizer, model): self.model = model def _prepare_and_validate_context( - self, context: Union[torch.Tensor, List[torch.Tensor]] - ): + self, context: Union[np.ndarray, List[np.ndarray]] + ) -> np.ndarray: if isinstance(context, list): context = left_pad_and_stack_1D(context) - assert isinstance(context, torch.Tensor) + assert isinstance(context, np.ndarray) if context.ndim == 1: - context = context.unsqueeze(0) + context = context[np.newaxis, ...] assert context.ndim == 2 return context - @torch.no_grad() def embed( - self, context: Union[torch.Tensor, List[torch.Tensor]] - ) -> Tuple[torch.Tensor, Any]: + self, context: Union[np.ndarray, List[np.ndarray]] + ) -> Tuple[np.ndarray, Any]: """ Get encoder embeddings for the given time series. @@ -367,36 +342,36 @@ def embed( or the length of the longest time series, if a list of 1D tensors was provided, and the extra 1 is for EOS. 
""" - context_tensor = self._prepare_and_validate_context(context=context) + context_array = self._prepare_and_validate_context(context=context) token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform( - context_tensor + context_array ) embeddings = self.model.encode( - input_ids=token_ids.to(self.model.device), - attention_mask=attention_mask.to(self.model.device), - ).cpu() - return embeddings, tokenizer_state + input_ids=mx.array(token_ids), + attention_mask=mx.array(attention_mask), + ) + return np.array(embeddings.astype(mx.float32)), tokenizer_state def predict( self, - context: Union[torch.Tensor, List[torch.Tensor]], + context: Union[np.ndarray, List[np.ndarray]], prediction_length: Optional[int] = None, num_samples: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, limit_prediction_length: bool = True, - ) -> torch.Tensor: + ) -> np.ndarray: """ Get forecasts for the given time series. Parameters ---------- context - Input series. This is either a 1D tensor, or a list - of 1D tensors, or a 2D tensor whose first dimension + Input series. This is either a 1D numpy array, or a list + of 1D numpy arrays, or a 2D numpy array whose first dimension is batch. In the latter case, use left-padding with - ``torch.nan`` to align series of different lengths. + ``np.nan`` to align series of different lengths. prediction_length Time steps to predict. Defaults to what specified in ``self.model.config``. @@ -421,10 +396,10 @@ def predict( Returns ------- samples - Tensor of sample forecasts, of shape + Numpy array of sample forecasts, of shape (batch_size, num_samples, prediction_length). """ - context_tensor = self._prepare_and_validate_context(context=context) + context_array = self._prepare_and_validate_context(context=context) if prediction_length is None: prediction_length = self.model.config.prediction_length @@ -432,7 +407,7 @@ def predict( if prediction_length > self.model.config.prediction_length: msg = ( f"We recommend keeping prediction length <= {self.model.config.prediction_length}. " - "The quality of longer predictions may degrade since the model is not optimized for it. " + f"The quality of longer predictions may degrade since the model is not optimized for it. " ) if limit_prediction_length: msg += "You can turn off this check by setting `limit_prediction_length=False`." 
@@ -444,20 +419,19 @@ def predict( while remaining > 0: token_ids, attention_mask, scale = self.tokenizer.input_transform( - context_tensor + context_array ) + token_ids, attention_mask = mx.array(token_ids), mx.array(attention_mask) samples = self.model( - token_ids.to(self.model.device), - attention_mask.to(self.model.device), + token_ids, + attention_mask, min(remaining, self.model.config.prediction_length), num_samples, temperature, top_k, top_p, ) - prediction = self.tokenizer.output_transform( - samples.to(scale.device), scale - ) + prediction = self.tokenizer.output_transform(np.array(samples), scale) predictions.append(prediction) remaining -= prediction.shape[-1] @@ -465,31 +439,44 @@ def predict( if remaining <= 0: break - context_tensor = torch.cat( - [context_tensor, prediction.median(dim=1).values], dim=-1 + context_array = np.concatenate( + [context_array, np.median(prediction, axis=1)], axis=-1 ) - return torch.cat(predictions, dim=-1) + return np.concatenate(predictions, axis=-1) @classmethod - def from_pretrained(cls, *args, **kwargs): + def from_pretrained( + cls, model_name_or_path: Union[str, Path], dtype: str = "float32" + ): """ Load the model, either from a local path or from the HuggingFace Hub. - Supports the same arguments as ``AutoConfig`` and ``AutoModel`` - from ``transformers``. + + Parameters + ---------- + model_name_or_path + Model ID on HuggingFace Hub or local path. + dtype, optional + String denoting the float dtype of the mlx model, + by default "float32" + + Returns + ------- + A ChronosPipeline """ - config = AutoConfig.from_pretrained(*args, **kwargs) + config = T5Config.from_pretrained(model_name_or_path) assert hasattr(config, "chronos_config"), "Not a Chronos config file" + dtype = getattr(mx, dtype) chronos_config = ChronosConfig(**config.chronos_config) - - if chronos_config.model_type == "seq2seq": - inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs) - else: - assert config.model_type == "causal" - inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) + inner_model = T5(config=config) + weights = translate_weights(model_name_or_path=model_name_or_path, dtype=dtype) + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + inner_model.update(weights) + mx.eval(inner_model.parameters()) return cls( tokenizer=chronos_config.create_tokenizer(), diff --git a/src/chronos_mlx/t5.py b/src/chronos_mlx/t5.py new file mode 100644 index 0000000..2ec7659 --- /dev/null +++ b/src/chronos_mlx/t5.py @@ -0,0 +1,420 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from ml-explore/mlx-examples: +# https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/t5.py +# Modifications: +# - Added support for attention mask. +# - Added support for top_k and top_p sampling. 
+ + +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from transformers import T5Config + + +def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 +): + # Adapted from HuggingFace transformers: + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets + relative_position = mx.abs(relative_position) + else: + relative_position = -mx.minimum( + relative_position, mx.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins + # in positions up to max_distance + scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) + relative_position_if_large = max_exact + ( + mx.log(relative_position.astype(mx.float32) / max_exact) * scale + ).astype(mx.int16) + relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += mx.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding( + config.relative_attention_num_buckets, config.num_heads + ) + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = _relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + inner_dim = config.d_kv * config.num_heads + self.num_heads = config.num_heads + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array], + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, _ = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + 
if cache is not None: + key_cache, value_cache = cache + keys = mx.concatenate([key_cache, keys], axis=3) + values = mx.concatenate([value_cache, values], axis=2) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scores = queries @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.out_proj(values_hat), (keys, values) + + +class DenseActivation(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) + else: + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") + + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.attention = MultiHeadAttention(config) + self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__(self, x, mask): + y = self.ln1(x) + y, _ = self.attention(y, y, y, mask=mask) + x = x + y + + y = self.ln2(x) + y = self.dense(y) + return x + y + + +class TransformerEncoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerEncoderLayer(config) for i in range(config.num_layers) + ] + self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None): + pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])[None] + if mask is not None: + mask = mask[:, None, None, :] + pos_bias += mask + for layer in self.layers: + x = layer(x, mask=pos_bias) + return self.ln(x) + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.self_attention = MultiHeadAttention(config) + self.cross_attention = MultiHeadAttention(config) + self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__( + self, + x: mx.array, + memory: mx.array, + mask: mx.array, + memory_mask: mx.array, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ): + y = self.ln1(x) + y, cache = self.self_attention(y, y, y, mask, cache) + x = x + y + + y = self.ln2(x) + y, _ = self.cross_attention(y, memory, memory, memory_mask) + x = x + y + + y = self.ln3(x) + y = self.dense(y) + x = x + y + + return x, cache + + +class TransformerDecoder(nn.Module): + def __init__(self, 
config: T5Config): + super().__init__() + n_layers = getattr(config, "num_decoder_layers", config.num_layers) + self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)] + self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) + + def __call__(self, x, memory, mask, memory_mask, cache=None): + if cache is not None: + offset = cache[0][0].shape[3] + else: + offset = 0 + cache = [None] * len(self.layers) + + T = offset + x.shape[1] + pos_bias = self.relative_attention_bias(T, T, offset=offset) + if mask is not None: + mask += pos_bias + else: + mask = pos_bias + + for e, layer in enumerate(self.layers): + x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e]) + x = self.ln(x) + + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: T5Config): + self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) + + def __call__(self, inputs): + return self.linear(inputs) + + +def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1): + assert min_tokens_to_keep <= logits.shape[-1] + logits_dtype = logits.dtype + # FIXME: The following is needed because mlx doesn't have the cumsum + # kernel for bfloat16. Once that is supported natively, this casting + # should be removed. @abdulfatir + logits = logits.astype(mx.float32) + sorted_indices = mx.argsort(logits, axis=-1) + sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1) + cumulative_probs = mx.softmax(sorted_logits, axis=-1).cumsum(axis=-1, reverse=True) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., -min_tokens_to_keep:] = False + masked_sorted_logits = mx.where(sorted_indices_to_remove, -mx.inf, sorted_logits) + unsorted_indices = mx.argsort(sorted_indices, axis=-1) + return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1).astype( + logits_dtype + ) + + +def sample(logits, top_k=1, top_p=1.0, temperature=1.0): + vocab_size = logits.shape[-1] + assert top_p <= 1.0, f"{top_p=} should be <= 1.0" + + if temperature == 0 or top_k == 1: + return mx.argmax(logits, axis=-1) + else: + # Apply temperature term + if temperature != 1.0: + logits /= temperature + + # Apply top_k + if top_k >= vocab_size: + return mx.random.categorical( + apply_top_p(logits, top_p=top_p) if top_p < 1.0 else logits + ) + + top_k_indices = mx.argpartition(logits, top_k, axis=-1)[..., -top_k:] + top_k_logits = mx.take_along_axis(logits, top_k_indices, axis=-1) + + # Apply top_p + if top_p < 1.0: + top_k_logits = apply_top_p(top_k_logits, top_p=top_p) + + return top_k_indices[ + mx.arange(top_k_indices.shape[0]), mx.random.categorical(top_k_logits) + ] + + +class T5(nn.Module): + def __init__(self, config: T5Config): + self.wte = nn.Embedding(config.vocab_size, config.d_model) + self.encoder = TransformerEncoder(config) + self.decoder = TransformerDecoder(config) + self.tie_word_embeddings = config.tie_word_embeddings + if not self.tie_word_embeddings: + self.lm_head = OutputHead(config) + self.model_dim = config.d_model + + def encode(self, inputs: mx.array, mask: mx.array): + return self.encoder(self.wte(inputs), mask) + + def decode( + self, + inputs: mx.array, + memory: mx.array, + memory_mask: mx.array, + cache=None, + ): + inputs = self.wte(inputs) + T = inputs.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(inputs.dtype) + else: + mask = None + + memory_mask = memory_mask[:, None, 
None, :] + y, cache = self.decoder( + inputs, memory=memory, mask=mask, memory_mask=memory_mask, cache=cache + ) + if not self.tie_word_embeddings: + y = self.lm_head(y) + else: + y *= self.model_dim**-0.5 + y = y @ self.wte.weight.T + return y, cache + + def __call__( + self, + inputs: mx.array, + mask: mx.array, + decoder_inputs: mx.array, + ): + memory = self.encode(inputs, mask=mask) + return self.decode(decoder_inputs, memory=memory, memory_mask=mask)[0] + + def generate( + self, + input_ids: mx.array, + attention_mask: mx.array, + min_new_tokens: Optional[int] = None, + max_new_tokens: int = 64, + do_sample: bool = True, + num_return_sequences: int = 1, + pad_token_id: int = 0, + eos_token_id: Optional[int] = None, + temperature: Optional[float] = 1.0, + top_k: int = 50, + top_p: float = 1.0, + ): + self.eval() + + def should_stop(current_token, num_sampled_tokens): + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if num_sampled_tokens >= max_new_tokens: + return True + return False + + top_k = top_k if do_sample else 1 + attention_mask = (1.0 - attention_mask.astype(mx.float32)) * -1e9 + memory = self.encode(input_ids, mask=attention_mask) + + repeated_memory = mx.repeat(memory, num_return_sequences, axis=0) + repeated_attention_mask = mx.repeat( + attention_mask, num_return_sequences, axis=0 + ) + decoder_start_id = pad_token_id + decoder_inputs = mx.array([decoder_start_id] * len(repeated_attention_mask))[ + :, None + ] + + cache = None + prediction = [decoder_inputs] + num_sampled_tokens = 0 + while not should_stop(prediction[-1], num_sampled_tokens): + logits, cache = self.decode( + prediction[-1], + repeated_memory, + memory_mask=repeated_attention_mask, + cache=cache, + ) + if ( + min_new_tokens is not None + and eos_token_id is not None + and num_sampled_tokens < min_new_tokens + ): + logits[..., eos_token_id] = -float("inf") + + y = sample( + logits[:, -1, :], top_k=top_k, top_p=top_p, temperature=temperature + ) + num_sampled_tokens += 1 + prediction.append(y[:, None]) + + return mx.concatenate(prediction, axis=-1) diff --git a/src/chronos_mlx/translate.py b/src/chronos_mlx/translate.py new file mode 100644 index 0000000..857b5dc --- /dev/null +++ b/src/chronos_mlx/translate.py @@ -0,0 +1,77 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Adapted from ml-explore/mlx-examples: +# https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/convert.py + +from pathlib import Path +from typing import Union + +import mlx.core as mx +import torch +from transformers import T5ForConditionalGeneration + +SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + +DECODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".self_attention."), + (".layer.1.EncDecAttention.", ".cross_attention."), + (".layer.2.DenseReluDense.", ".dense."), +] + + +def replace_key(key: str) -> str: + for old, new in SHARED_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + if key.startswith("encoder."): + for old, new in ENCODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + elif key.startswith("decoder."): + for old, new in DECODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + return key + + +def translate_weights(model_name_or_path: Union[str, Path], dtype: mx.Dtype): + """Translate a HuggingFace transformers T5 model to MLX. + + Parameters + ---------- + model_name_or_path + HuggingFace model name or local path. + dtype + mlx dtype for the resulting mlx weights. + + Returns + ------- + A state dictionary with weights as mlx arrays. + """ + model = T5ForConditionalGeneration.from_pretrained( + model_name_or_path, torch_dtype=torch.float32 + ) + weights = { + replace_key(k): mx.array(v.numpy(), dtype=dtype) + for k, v in model.state_dict().items() + } + return weights diff --git a/test/test_chronos.py b/test/test_chronos.py deleted file mode 100644 index a7d63bc..0000000 --- a/test/test_chronos.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0 - -from pathlib import Path -from typing import Tuple - -import torch -import pytest - -from chronos import ChronosConfig, ChronosPipeline - - -@pytest.mark.xfail -@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27]) -@pytest.mark.parametrize("n_special_tokens", [2, 5, 13]) -@pytest.mark.parametrize("use_eos_token", [False, True]) -def test_tokenizer_fixed_data( - n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool -): - n_tokens = n_numerical_tokens + n_special_tokens - context_length = 3 - - config = ChronosConfig( - tokenizer_class="MeanScaleUniformBins", - tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), - n_tokens=n_tokens, - n_special_tokens=n_special_tokens, - pad_token_id=0, - eos_token_id=1, - use_eos_token=use_eos_token, - model_type="seq2seq", - context_length=512, - prediction_length=64, - num_samples=20, - temperature=1.0, - top_k=50, - top_p=1.0, - ) - - tokenizer = config.create_tokenizer() - - context = torch.tensor( - [ - [-3.7, 3.7], - [-42.0, 42.0], - ] - ) - batch_size, _ = context.shape - - token_ids, attention_mask, scale = tokenizer.input_transform(context) - - assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token) - assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size)) - assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size)) - assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size)) - - if use_eos_token: - assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size)) - - samples = tokenizer.output_transform( - torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1), - tokenizer_state=scale, - ) - - assert (samples[:, 0, [0, -1]] == context).all() - - -@pytest.mark.xfail -@pytest.mark.parametrize("use_eos_token", [False, True]) -def test_tokenizer_random_data(use_eos_token: bool): - context_length = 8 - n_tokens = 256 - n_special_tokens = 2 - - config = ChronosConfig( - tokenizer_class="MeanScaleUniformBins", - tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), - n_tokens=n_tokens, - n_special_tokens=n_special_tokens, - pad_token_id=0, - eos_token_id=1, - use_eos_token=use_eos_token, - model_type="seq2seq", - context_length=context_length, - prediction_length=64, - num_samples=20, - temperature=1.0, - top_k=50, - top_p=1.0, - ) - - tokenizer = config.create_tokenizer() - - context = torch.tensor( - [ - [torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0], - [3.0, torch.nan, 3.9, 4.0, 4.1, 4.9], - ] - ) - - token_ids, attention_mask, scale = tokenizer.input_transform(context) - - assert token_ids.shape == ( - *context.shape[:-1], - context_length + 1 * use_eos_token, - ) - assert attention_mask.shape == ( - *context.shape[:-1], - context_length + 1 * use_eos_token, - ) - assert scale.shape == context.shape[:1] - - sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4)) - sample_ids[0, 0, 0] = n_special_tokens - sample_ids[-1, -1, -1] = n_tokens - 1 - - samples = tokenizer.output_transform(sample_ids, scale) - - assert samples.shape == (2, 10, 4) - - -def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None: - assert isinstance(samples, torch.Tensor) - assert samples.shape == shape - - -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_predict(torch_dtype: str): - pipeline = ChronosPipeline.from_pretrained( - Path(__file__).parent / "dummy-chronos-model", - device_map="cpu", - torch_dtype=torch_dtype, - ) - 
context = 10 * torch.rand(size=(4, 16)) + 10 - - # input: tensor of shape (batch_size, context_length) - - samples = pipeline.predict(context, num_samples=12, prediction_length=3) - validate_tensor(samples, (4, 12, 3)) - - with pytest.raises(ValueError): - samples = pipeline.predict(context, num_samples=7, prediction_length=65) - - samples = pipeline.predict( - context, num_samples=7, prediction_length=65, limit_prediction_length=False - ) - validate_tensor(samples, (4, 7, 65)) - - # input: batch_size-long list of tensors of shape (context_length,) - - samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) - validate_tensor(samples, (4, 12, 3)) - - with pytest.raises(ValueError): - samples = pipeline.predict(list(context), num_samples=7, prediction_length=65) - - samples = pipeline.predict( - list(context), - num_samples=7, - prediction_length=65, - limit_prediction_length=False, - ) - validate_tensor(samples, (4, 7, 65)) - - # input: tensor of shape (context_length,) - - samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) - validate_tensor(samples, (1, 12, 3)) - - with pytest.raises(ValueError): - samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65) - - samples = pipeline.predict( - context[0, ...], - num_samples=7, - prediction_length=65, - limit_prediction_length=False, - ) - validate_tensor(samples, (1, 7, 65)) - - -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_embed(torch_dtype: str): - pipeline = ChronosPipeline.from_pretrained( - Path(__file__).parent / "dummy-chronos-model", - device_map="cpu", - torch_dtype=torch_dtype, - ) - d_model = pipeline.model.model.config.d_model - context = 10 * torch.rand(size=(4, 16)) + 10 - expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) - - # input: tensor of shape (batch_size, context_length) - - embedding, scale = pipeline.embed(context) - validate_tensor(embedding, (4, expected_embed_length, d_model)) - validate_tensor(scale, (4,)) - - # input: batch_size-long list of tensors of shape (context_length,) - - embedding, scale = pipeline.embed(list(context)) - validate_tensor(embedding, (4, expected_embed_length, d_model)) - validate_tensor(scale, (4,)) - - # input: tensor of shape (context_length,) - embedding, scale = pipeline.embed(context[0, ...]) - validate_tensor(embedding, (1, expected_embed_length, d_model)) - validate_tensor(scale, (1,)) diff --git a/test/test_chronos_mlx.py b/test/test_chronos_mlx.py new file mode 100644 index 0000000..4f718a5 --- /dev/null +++ b/test/test_chronos_mlx.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Tuple + +import mlx.core as mx +import numpy as np +import pytest + +from chronos_mlx.t5 import apply_top_p +from chronos_mlx import ChronosPipeline + + +def validate_array(samples: np.ndarray, shape: Tuple[int, ...]) -> None: + assert isinstance(samples, np.ndarray) + assert samples.shape == shape + + +@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) +def test_pipeline_predict(dtype: str): + pipeline = ChronosPipeline.from_pretrained( + Path(__file__).parent / "dummy-chronos-model", + dtype=dtype, + ) + context = 10 * np.random.rand(4, 16) + 10 + + # input: tensor of shape (batch_size, context_length) + + samples = pipeline.predict(context, num_samples=12, prediction_length=3) + validate_array(samples, (4, 12, 3)) + + with pytest.raises(ValueError): + samples = pipeline.predict(context, num_samples=7, prediction_length=65) + + samples = pipeline.predict( + context, num_samples=7, prediction_length=65, limit_prediction_length=False + ) + validate_array(samples, (4, 7, 65)) + + # input: batch_size-long list of tensors of shape (context_length,) + + samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) + validate_array(samples, (4, 12, 3)) + + with pytest.raises(ValueError): + samples = pipeline.predict(list(context), num_samples=7, prediction_length=65) + + samples = pipeline.predict( + list(context), + num_samples=7, + prediction_length=65, + limit_prediction_length=False, + ) + validate_array(samples, (4, 7, 65)) + + # input: tensor of shape (context_length,) + + samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) + validate_array(samples, (1, 12, 3)) + + with pytest.raises(ValueError): + samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65) + + samples = pipeline.predict( + context[0, ...], + num_samples=7, + prediction_length=65, + limit_prediction_length=False, + ) + validate_array(samples, (1, 7, 65)) + + +@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) +def test_pipeline_embed(dtype: str): + pipeline = ChronosPipeline.from_pretrained( + Path(__file__).parent / "dummy-chronos-model", + dtype=dtype, + ) + d_model = pipeline.model.model.model_dim + context = 10 * np.random.rand(4, 16) + 10 + expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) + + # input: tensor of shape (batch_size, context_length) + + embedding, scale = pipeline.embed(context) + validate_array(embedding, (4, expected_embed_length, d_model)) + validate_array(scale, (4,)) + + # input: batch_size-long list of tensors of shape (context_length,) + + embedding, scale = pipeline.embed(list(context)) + validate_array(embedding, (4, expected_embed_length, d_model)) + validate_array(scale, (4,)) + + # input: tensor of shape (context_length,) + embedding, scale = pipeline.embed(context[0, ...]) + validate_array(embedding, (1, expected_embed_length, d_model)) + validate_array(scale, (1,)) + + +@pytest.mark.parametrize( + "top_p,expected_non_zero_probs", + [ + ( + 0.1, + mx.array( + [ + [False, True, False, False], + [False, True, False, False], + [True, False, False, False], + [True, False, False, False], + [False, False, False, True], + ] + ), + ), + ( + 0.5, + mx.array( + [ + [False, True, False, False], + [False, True, False, False], + [True, False, False, False], + [True, False, False, False], + [False, False, True, True], + ] + ), + ), + ( + 0.95, + mx.array( + [ + [False, True, True, True], + [False, True, 
False, True], + [True, False, False, False], + [True, True, False, False], + [False, True, True, True], + ] + ), + ), + ], +) +def test_apply_top_p(top_p: float, expected_non_zero_probs: mx.array): + probs = mx.array( + [ + [0.1, 0.4, 0.3, 0.2], + [0.01, 0.39, 0.25, 0.35], + [0.9, 0.01, 0.01, 0.08], + [0.7, 0.2, 0.05, 0.05], + [0.25, 0.25, 0.25, 0.25], + ], + ) + top_p_probs = mx.softmax(apply_top_p(probs.log(), top_p=top_p), axis=-1) + assert mx.all(mx.not_equal(top_p_probs, 0.0) == expected_non_zero_probs)
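
The expected masks in `test_apply_top_p` above can be hard to read directly off the diff. As a rough reference, and assuming plain NumPy rather than the MLX implementation itself, the sketch below computes which tokens survive the top-p filter used in `src/chronos_mlx/t5.py`: a token is kept when its probability plus the mass of all more-probable tokens does not exceed `top_p`, and the most probable token is always kept. Exact boundary cases (probabilities summing to exactly `top_p`) may differ from the test table because the MLX code operates on floating-point logits.

```python
import numpy as np

def top_p_mask(probs: np.ndarray, top_p: float, min_tokens_to_keep: int = 1) -> np.ndarray:
    """Boolean mask of tokens kept by top-p (nucleus) filtering."""
    order = np.argsort(probs, axis=-1)                 # ascending, like mx.argsort
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    # Reverse cumulative sum: mass of each token plus all more-probable tokens.
    tail_mass = sorted_probs[..., ::-1].cumsum(axis=-1)[..., ::-1]
    remove = tail_mass > top_p
    remove[..., -min_tokens_to_keep:] = False          # always keep the top tokens
    keep = np.empty_like(remove)
    np.put_along_axis(keep, order, ~remove, axis=-1)   # scatter back to original token order
    return keep

probs = np.array([[0.1, 0.4, 0.3, 0.2]])
print(top_p_mask(probs, top_p=0.5))  # [[False True False False]]: only the 0.4 token survives
```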