Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify pretraining README snippet #160

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
- Convert your time series dataset into a GluonTS-compatible file dataset. We recommend using the arrow format. You may use the `convert_to_arrow` function from the following snippet for that. Optionally, you may use [synthetic data from KernelSynth](#generating-synthetic-time-series-kernelsynth) to follow along.
```py
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Union

import numpy as np
from gluonts.dataset.arrow import ArrowWriter
Expand All @@ -35,18 +35,26 @@
def convert_to_arrow(
path: Union[str, Path],
time_series: Union[List[np.ndarray], np.ndarray],
start_times: Optional[Union[List[np.datetime64], np.ndarray]] = None,
compression: str = "lz4",
):
if start_times is None:
# Set an arbitrary start time
start_times = [np.datetime64("2000-01-01 00:00", "s")] * len(time_series)
"""
Store a given set of series into Arrow format at the specified path.

Input data can be either a list of 1D numpy arrays, or a single 2D
numpy array of shape (num_series, time_length).
"""
assert isinstance(time_series, list) or (
isinstance(time_series, np.ndarray) and
time_series.ndim == 2
)

assert len(time_series) == len(start_times)
# Set an arbitrary start time
start = np.datetime64("2000-01-01 00:00", "s")

dataset = [
{"start": start, "target": ts} for ts, start in zip(time_series, start_times)
{"start": start, "target": ts} for ts in time_series
]

ArrowWriter(compression=compression).write_to_file(
dataset,
path=path,
Expand All @@ -59,7 +67,6 @@

# Convert to GluonTS arrow format
convert_to_arrow("./noise-data.arrow", time_series=time_series)

```
- Modify the [training configs](training/configs) to use your data. Let's use the KernelSynth data as an example.
```yaml
Expand Down