Add KernelSynth script #64

Merged · 4 commits · May 9, 2024
README.md (1 addition, 1 deletion)
@@ -10,7 +10,7 @@

## 🚀 News

-- **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training).
+- **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training). We also added [a script](./scripts/kernel-synth.py) for generating synthetic time series data from Gaussian processes (KernelSynth; see Section 4.2 in the paper for details).
- **19 Apr 2024**: 🚀 Chronos is now supported on [AutoGluon-TimeSeries](https://auto.gluon.ai/stable/tutorials/timeseries/index.html), the powerful AutoML package for time series forecasting which enables model ensembles, cloud deployments, and much more. Get started with the [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
- **08 Apr 2024**: 🧪 Experimental [MLX inference support](https://github.com/amazon-science/chronos-forecasting/tree/mlx) added. If you have an Apple Silicon Mac, you can now obtain significantly faster forecasts from Chronos compared to CPU inference. This provides an alternative way to exploit the GPU on your Apple Silicon Macs together with the "mps" support in PyTorch.
- **25 Mar 2024**: [v1.1.0 released](https://github.com/amazon-science/chronos-forecasting/releases/tag/v1.1.0) with inference optimizations and `pipeline.embed` to extract encoder embeddings from Chronos.
pyproject.toml (1 addition, 0 deletions)
@@ -13,6 +13,7 @@ dependencies = [
test = ["pytest~=8.0", "numpy~=1.21"]
typecheck = ["mypy~=1.9"]
training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config"]
+kernel-synth = ["gluonts[pro]", "joblib", "scikit-learn"]

[tool.mypy]
ignore_missing_imports = true
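
Note: from a local checkout of the repository, the new optional dependency group can be installed with a standard extras install, e.g. `pip install --editable ".[kernel-synth]"` (the exact invocation may vary with your environment).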
scripts/kernel-synth.py (197 additions, new file)
@@ -0,0 +1,197 @@
import argparse
import functools
from pathlib import Path
from typing import Optional

import numpy as np
from gluonts.dataset.arrow import ArrowWriter
from joblib import Parallel, delayed
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
RBF,
ConstantKernel,
DotProduct,
ExpSineSquared,
Kernel,
RationalQuadratic,
WhiteKernel,
)
from tqdm.auto import tqdm

LENGTH = 1024
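# The trailing comments indicate the sampling frequency for which each
# periodicity is natural: e.g., 24 / LENGTH is a daily cycle in hourly (H)
# data, and 52 / LENGTH is a yearly cycle in weekly (W) data.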
KERNEL_BANK = [
ExpSineSquared(periodicity=24 / LENGTH), # H
ExpSineSquared(periodicity=48 / LENGTH), # 0.5H
ExpSineSquared(periodicity=96 / LENGTH), # 0.25H
ExpSineSquared(periodicity=24 * 7 / LENGTH), # H
ExpSineSquared(periodicity=48 * 7 / LENGTH), # 0.5H
ExpSineSquared(periodicity=96 * 7 / LENGTH), # 0.25H
ExpSineSquared(periodicity=7 / LENGTH), # D
ExpSineSquared(periodicity=14 / LENGTH), # 0.5D
ExpSineSquared(periodicity=30 / LENGTH), # D
ExpSineSquared(periodicity=60 / LENGTH), # 0.5D
ExpSineSquared(periodicity=365 / LENGTH), # D
ExpSineSquared(periodicity=365 * 2 / LENGTH), # 0.5D
ExpSineSquared(periodicity=4 / LENGTH), # W
ExpSineSquared(periodicity=26 / LENGTH), # W
ExpSineSquared(periodicity=52 / LENGTH), # W
ExpSineSquared(periodicity=4 / LENGTH), # M
ExpSineSquared(periodicity=6 / LENGTH), # M
ExpSineSquared(periodicity=12 / LENGTH), # M
ExpSineSquared(periodicity=4 / LENGTH), # Q
ExpSineSquared(periodicity=4 * 10 / LENGTH), # Q
ExpSineSquared(periodicity=10 / LENGTH), # Y
DotProduct(sigma_0=0.0),
DotProduct(sigma_0=1.0),
DotProduct(sigma_0=10.0),
RBF(length_scale=0.1),
RBF(length_scale=1.0),
RBF(length_scale=10.0),
RationalQuadratic(alpha=0.1),
RationalQuadratic(alpha=1.0),
RationalQuadratic(alpha=10.0),
WhiteKernel(noise_level=0.1),
WhiteKernel(noise_level=1.0),
ConstantKernel(),
]


def random_binary_map(a: Kernel, b: Kernel):
"""
Applies a random binary operator (+ or *) with equal probability
on kernels ``a`` and ``b``.

Parameters
----------
a
A GP kernel.
b
A GP kernel.

Returns
-------
The composite kernel `a + b` or `a * b`.
"""
binary_maps = [lambda x, y: x + y, lambda x, y: x * y]
return np.random.choice(binary_maps)(a, b)
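
# Note: functools.reduce applies random_binary_map left-associatively, so
# reducing [RBF(), ExpSineSquared(), WhiteKernel()] can yield composites
# such as (RBF() + ExpSineSquared()) * WhiteKernel().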


def sample_from_gp_prior(
kernel: Kernel, X: np.ndarray, random_seed: Optional[int] = None
):
"""
Draw a sample from a GP prior.

Parameters
----------
kernel
        The GP covariance kernel.
X
The input "time" points.
random_seed, optional
The random seed for sampling, by default None.

Returns
-------
A time series sampled from the GP prior.
"""
if X.ndim == 1:
X = X[:, None]

assert X.ndim == 2
gpr = GaussianProcessRegressor(kernel=kernel)
    ts = gpr.sample_y(X, n_samples=1, random_state=random_seed)  # shape: (len(X), 1)

return ts
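
# Note: GaussianProcessRegressor.sample_y above uses scikit-learn's default
# SVD-based sampling; sample_from_gp_prior_efficient below swaps this for a
# faster decomposition ("eigh" or "cholesky").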


def sample_from_gp_prior_efficient(
kernel: Kernel,
X: np.ndarray,
random_seed: Optional[int] = None,
method: str = "eigh",
):
"""
    Draw a sample from a GP prior. An efficient version that allows specifying
    the sampling method. The default sampling method used in GaussianProcessRegressor
    is based on SVD, which is significantly slower than alternatives such as
    `eigh` and `cholesky`.

Parameters
----------
kernel
        The GP covariance kernel.
X
The input "time" points.
random_seed, optional
The random seed for sampling, by default None.
method, optional
The sampling method for multivariate_normal, by default `eigh`.

Returns
-------
A time series sampled from the GP prior.
"""
if X.ndim == 1:
X = X[:, None]

assert X.ndim == 2

cov = kernel(X)
ts = np.random.default_rng(seed=random_seed).multivariate_normal(
mean=np.zeros(X.shape[0]), cov=cov, method=method
)

return ts


def generate_time_series(max_kernels: int = 5):
"""Generate a synthetic time series from KernelSynth.

Parameters
----------
max_kernels, optional
The maximum number of base kernels to use for each time series, by default 5

Returns
-------
A time series generated by KernelSynth.
"""
while True:
X = np.linspace(0, 1, LENGTH)

        # Randomly select up to max_kernels kernels from the KERNEL_BANK
selected_kernels = np.random.choice(
KERNEL_BANK, np.random.randint(1, max_kernels + 1), replace=True
)

# Combine the sampled kernels using random binary operators
kernel = functools.reduce(random_binary_map, selected_kernels)

        # Sample a time series from the GP prior; if the composite kernel's
        # covariance matrix is numerically singular, discard it and retry
        try:
            ts = sample_from_gp_prior(kernel=kernel, X=X)
        except np.linalg.LinAlgError as err:
            print("Error caught:", err)
            continue

# The timestamp is arbitrary
return {"start": np.datetime64("2000-01-01 00:00", "s"), "target": ts.squeeze()}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-N", "--num-series", type=int, default=1000_000)
parser.add_argument("-J", "--max-kernels", type=int, default=5)
args = parser.parse_args()
path = Path(__file__).parent / "kernelsynth-data.arrow"

    generated_dataset = Parallel(n_jobs=-1)(  # use all available CPU cores
delayed(generate_time_series)(max_kernels=args.max_kernels)
for _ in tqdm(range(args.num_series))
)

ArrowWriter(compression="lz4").write_to_file(
generated_dataset,
path=path,
)
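
Usage: the script is run directly, e.g. `python scripts/kernel-synth.py --num-series 1000 --max-kernels 5`, and writes `kernelsynth-data.arrow` next to itself. Below is a minimal sketch for inspecting the output, assuming a recent GluonTS version whose `FileDataset` accepts the Arrow files produced by `ArrowWriter`:

from pathlib import Path

from gluonts.dataset.common import FileDataset

# Load the generated dataset; "h" is an arbitrary frequency consistent
# with the placeholder start timestamps written by the script.
dataset = FileDataset(Path("scripts/kernelsynth-data.arrow"), freq="h")

# Each entry has a "start" timestamp and a "target" array of LENGTH
# (= 1024) values drawn from a randomly composed GP prior.
for i, entry in enumerate(dataset):
    print(entry["start"], len(entry["target"]))
    if i >= 2:
        break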