
Merge pull request #145 from google-research/rajat_dev
Full pytorch support
rajatsen91 authored Sep 13, 2024
2 parents 16d84e9 + 4b62630 commit d065f8c
Showing 19 changed files with 2,409 additions and 1,938 deletions.
109 changes: 50 additions & 59 deletions README.md
@@ -16,14 +16,14 @@ This is not an officially supported Google product.

We recommend at least 16GB RAM to load TimesFM dependencies.

-## Update - Aug. 6, 2024

## Update - Sep. 12, 2024
- We have released full PyTorch support (excluding the PEFT parts).
- Shoutout to @tanmayshishodia for checking in PEFT methods like LoRA and DoRA.
- To install TimesFM, you can now simply do: `pip install timesfm`.
- Launched [finetuning support](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb) that lets you finetune the weights of the pretrained TimesFM model on your own data.
- Launched [~zero-shot covariate support](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb) with external regressors. More details [here](https://github.com/google-research/timesfm?tab=readme-ov-file#covariates-support), and see the sketch below.
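
As a rough illustration, a covariate-assisted forecast looks like the sketch below. This is a minimal, unofficial example: the call follows the covariates notebook linked above, and the toy series, covariate values, and sizes are ours.

```python
import numpy as np
import timesfm

# Load the PyTorch checkpoint with a short horizon for this toy example.
model = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(backend="cpu", horizon_len=7),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m-pytorch"),
)

# Two series of 28 points each, with a "weekday" categorical covariate that
# must cover both the context (28) and the horizon (7), i.e. 35 values.
inputs = [np.random.rand(28).tolist(), np.random.rand(28).tolist()]
weekday = [list(range(7)) * 5, list(range(7)) * 5]

cov_forecast, ols_forecast = model.forecast_with_covariates(
    inputs=inputs,
    dynamic_categorical_covariates={"weekday": weekday},
    freq=[0, 0],
    xreg_mode="xreg + timesfm",  # fit covariates, then model the residuals
)
```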

-## Checkpoint timesfm-1.0-200m
## Checkpoint timesfm-1.0-200m (-pytorch)

timesfm-1.0-200m is the first open model checkpoint:

@@ -39,68 +39,55 @@ Please look into the README files in the respective benchmark directories within

## Installation

### Installation as a package

To install TimesFM as a package, you can run the following command without cloning this repo:

`pip install timesfm`

-### Installation using conda
### Local installation using poetry

-For calling TimesFM, We have two environment files. Inside `timesfm`, for
-GPU installation (assuming CUDA 12 has been setup), you can create a conda
-environment `tfm_env` from the base folder through:
We will be using `pyenv` and `poetry`. To set these up, please follow the instructions [here](https://substack.com/home/post/p-148747960?r=28a5lx&utm_campaign=post&utm_medium=web). Note that the PAX (or JAX) version needs to run on Python 3.10.x, while the PyTorch version can run on >=3.11.x. Therefore, make sure you have both versions of Python installed:

```
-conda env create --file=environment.yml
pyenv install 3.10
pyenv install 3.11
pyenv versions  # list the available versions (let's assume they are 3.10.15 and 3.11.10)
```

-For a CPU setup please use,
### For PAX version installation, do the following:

```
-conda env create --file=environment_cpu.yml
pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install --only pax
```
-to create the environment instead.

-Follow by
After that, you can run TimesFM under `poetry shell` or via `poetry run python3 ...`.

### For PyTorch version installation, do the following:

```
-conda activate tfm_env
-pip install -e .
pyenv local 3.11.10
poetry env use 3.11.10
poetry lock
poetry install --only torch
```
-to install the package.

After that, you can run TimesFM under `poetry shell` or via `poetry run python3 ...`.
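
A quick sanity check for either environment (the file name `smoke_test.py` and the check itself are just our suggestion):

```python
# smoke_test.py -- run as `poetry run python3 smoke_test.py`.
import timesfm

# If the import resolves from the project's virtualenv, the install worked;
# loading a checkpoint (see Usage below) is the fuller end-to-end test.
print("timesfm imported from:", timesfm.__file__)
```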

-**Note**:

-1. Running the provided benchmarks would require additional dependencies.
-Please use the environment files under `experiments` instead.
-Please see the `experiments` section fro more instructions.

-2. The dependency `lingvo` does not support ARM architectures, and the code is not working for machines with Apple silicon. We are aware of this issue and are working on a solution. Stay tuned.

-### Local installation using poetry

-To from the current repository/local version (like you would have previously done with `pip -e .`), you can run the command

-```
-pip install poetry # optional
-poetry install
-```

-This will install the environment in the local .venv folder (depends on the configuration) and matches the python command to the poetry environment. If this is not the case, you can use `poetry run python` to use the local environment.

### Notes

-1. Running the provided benchmarks would require additional dependencies.
-Please use the environment files under `experiments` instead.
1. Running the provided benchmarks would require additional dependencies. Please see the `experiments` folder.

-2. The dependency `lingvo` does not support ARM architectures, and the code is not working for machines with Apple silicon. We are aware of this issue and are working on a solution. Stay tuned.
2. The dependency `lingvo` does not support ARM architectures, and the PAX version is not working for machines with Apple silicon.

-#### Building the package and publishing to PyPI
### Install from PyPI (and publish)

-The package can be built using the command `poetry build`.
Instructions coming soon.

-To build and publish it to PyPI, the command `poetry publish` can be used. This command will require the user to have the necessary permissions to publish to the PyPI repository.

## Usage

@@ -110,32 +97,36 @@ Then the base class can be loaded as,
```python
import timesfm

# For PAX
tfm = timesfm.TimesFm(
-    context_len=<context>,
-    horizon_len=<horizon>,
-    input_patch_len=32,
-    output_patch_len=128,
-    num_layers=20,
-    model_dims=1280,
-    backend=<backend>,
-)
-tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        horizon_len=128,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m"),
)

# For Torch
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        horizon_len=128,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m-pytorch"),
)
```

-Note that the four parameters are fixed to load the 200m model

-```python
-input_patch_len=32,
-output_patch_len=128,
-num_layers=20,
-model_dims=1280,
-```
Note that some of the parameters are fixed when loading the 200m model.

-1. The `context_len` here can be set as the max context length **of the model**. **It needs to be a multiplier of `input_patch_len`, i.e. a multiplier of 32.** You can provide a shorter series to the `tfm.forecast()` function and the model will handle it. Currently, the model handles a max context length of 512, which can be increased in later releases. The input time series can have **any context length**. Padding / truncation will be handled by the inference code if needed.
1. The `context_len` in `hparams` can be set as the maximum context length **of the model**. **It needs to be a multiple of `input_patch_len`, i.e. a multiple of 32.** You can provide a shorter series to the `tfm.forecast()` function and the model will handle it. Currently, the model handles a maximum context length of 512, which can be increased in later releases. The input time series can have **any context length**; padding/truncation will be handled by the inference code if needed.

2. The horizon length can be set to anything. We recommend setting it to the largest horizon length you would need in the forecasting tasks for your application. We generally recommend horizon length <= context length but it is not a requirement in the function call.

-3. `backend` is one of "cpu", "gpu" or "tpu", case sensitive.
3. `backend` is one of "cpu" or "gpu", case sensitive.

### Perform inference
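
For example, a minimal forecast call on toy inputs (the series and frequency categories below are illustrative):

```python
import numpy as np

# A batch of two series with different lengths; `freq` is a per-series
# category: 0 for high frequency (e.g. daily and finer), 1 for medium
# (e.g. weekly, monthly), 2 for low (e.g. quarterly, yearly).
forecast_input = [
    np.sin(np.linspace(0, 20, 100)),
    np.sin(np.linspace(0, 20, 256)),
]
frequency_input = [0, 1]

point_forecast, experimental_quantile_forecast = tfm.forecast(
    forecast_input,
    freq=frequency_input,
)
print(point_forecast.shape)  # (2, horizon_len) -> (2, 128) with the hparams above
```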

21 changes: 0 additions & 21 deletions environment.yml

This file was deleted.

21 changes: 0 additions & 21 deletions environment_cpu.yml

This file was deleted.

28 changes: 0 additions & 28 deletions experiments/environment.yml

This file was deleted.

28 changes: 0 additions & 28 deletions experiments/environment_cpu.yml

This file was deleted.

16 changes: 11 additions & 5 deletions experiments/extended_benchmarks/README.md
@@ -5,14 +5,20 @@ The benchmark setting has been borrowed from Nixtla's original [benchmarking](ht

## Running TimesFM on the benchmark

-Install the environment and the package as detailed in the main README and then follow the steps from the base directory.
We need to add the following packages for running these benchmarks. Follow the installation instructions up to just before `poetry lock`; then:

```
-conda activate tfm_env
-TF_CPP_MIN_LOG_LEVEL=2 XLA_PYTHON_CLIENT_PREALLOCATE=false python3 -m experiments.extended_benchmarks.run_timesfm --model_path=<model_path> --backend="gpu"
poetry add git+https://github.com/awslabs/gluon-ts.git
poetry lock
poetry install --only <pax or pytorch>
```

To run TimesFM on the benchmark, do:

```
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m(-pytorch) --backend="gpu"
```

-In the above, `<model_path>` should point to the checkpoint directory that can be downloaded from HuggingFace.

Note: In the current version of TimesFM we focus on point forecasts, and therefore the MASE and sMAPE numbers have been calculated using the quantile head corresponding to the median, i.e. the 0.5 quantile. We do offer 10 quantile heads, but they have not been calibrated after pretraining. We recommend using them with caution, or calibrating/conformalizing them on a holdout set for your applications. More to follow in later versions.
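
Concretely, pulling the median out of the quantile output might look like the following sketch, assuming the quantile axis stacks the mean first and then the requested quantiles in order:

```python
import numpy as np

quantiles = list(np.arange(1, 10) / 10.0)  # 0.1, ..., 0.9, as in run_timesfm.py
median_index = 1 + quantiles.index(0.5)    # index 0 is assumed to be the mean
median_forecast = experimental_quantile_forecast[..., median_index]
```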

@@ -22,7 +28,7 @@

__Update:__ We have added TimeGPT-1 to the benchmark results. We had to remove the Dominick dataset as we were not able to run TimeGPT-1 on this benchmark. Note that the previous results including Dominick remain available at `./tfm_results.png`. In order to reproduce the results for TimeGPT-1, please run `run_timegpt.py`.

-_Remark:_ All baselines except the ones involving TimeGPT were run performed on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus). Since TimeGPT-1 can only be accessed by an API, the time column might not reflect the true speed of the model as it also includes the communication cost. Moreover, we are not sure about the exact backend hardware for TimeGPT.
_Remark:_ All baselines except the ones involving TimeGPT were run on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus). Since TimeGPT-1 can only be accessed through an API, the time column might not reflect the true speed of the model, as it also includes the communication cost. Moreover, we are not sure about the exact backend hardware for TimeGPT. The TimesFM latency numbers are from the PAX version.

We can see that TimesFM performs the best in terms of both MASE and sMAPE. More importantly, it is much faster than the other methods; in particular, it is more than 600x faster than StatisticalEnsemble and 80x faster than Chronos (Large).

35 changes: 12 additions & 23 deletions experiments/extended_benchmarks/run_timesfm.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation script for timesfm."""

import os
@@ -21,12 +20,10 @@
from absl import flags
import numpy as np
import pandas as pd
-from paxml import checkpoints
import timesfm

from .utils import ExperimentHandler


dataset_names = [
"m1_monthly",
"m1_quarterly",
@@ -74,35 +71,27 @@
"m4_yearly": 64,
}

-_MODEL_PATH = flags.DEFINE_string(
-    "model_path", "/home/timesfm_q10_20240501", "Path to model"
-)
_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-1.0-200m",
                                  "Path to model")
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
_HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
_BACKEND = flags.DEFINE_string("backend", "gpu", "Backend")
_NUM_JOBS = flags.DEFINE_integer("num_jobs", 1, "Number of jobs")
_SAVE_DIR = flags.DEFINE_string("save_dir", "./results", "Save directory")


QUANTILES = list(np.arange(1, 10) / 10.0)
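# Deciles 0.1, ..., 0.9; per the benchmark README, the 0.5 head supplies the
# point forecasts scored by MASE/sMAPE.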


def main():
  results_list = []
  tfm = timesfm.TimesFm(
-      context_len=512,
-      horizon_len=_HORIZON.value,
-      input_patch_len=32,
-      output_patch_len=128,
-      num_layers=20,
-      model_dims=1280,
-      backend=_BACKEND.value,
-      per_core_batch_size=_BATCH_SIZE.value,
-      quantiles=QUANTILES,
-  )
-  tfm.load_from_checkpoint(
-      _MODEL_PATH.value,
-      checkpoint_type=checkpoints.CheckpointType.FLAX,
      hparams=timesfm.TimesFmHparams(
          backend=_BACKEND.value,
          per_core_batch_size=_BATCH_SIZE.value,
          horizon_len=_HORIZON.value,
      ),
      checkpoint=timesfm.TimesFmCheckpoint(
          huggingface_repo_id=_MODEL_PATH.value),
  )
run_id = np.random.randint(100000)
model_name = "timesfm"
@@ -127,9 +116,9 @@ def main():
  )
  total_time = time.time() - init_time
  time_df = pd.DataFrame({"time": [total_time], "model": model_name})
-  results = exp.evaluate_from_predictions(
-      models=[model_name], fcsts_df=fcsts_df, times_df=time_df
-  )
  results = exp.evaluate_from_predictions(models=[model_name],
                                          fcsts_df=fcsts_df,
                                          times_df=time_df)
  print(results, flush=True)
  results_list.append(results)
  results_full = pd.concat(results_list)