Add TSMixer model #1807

Closed
alexcolpitts96 opened this issue May 31, 2023 · 12 comments
Labels
new model Adding a new model


@alexcolpitts96
Contributor

I recently found TSMixer (http://arxiv.org/abs/2303.06053).

It is very similar to TiDE (#1726) but with a few tweaks.

It should be pretty straightforward to implement based on the implementation of TiDE (#1727).

I will try to get started on it in the next few days.

alexcolpitts96 added the triage (Issue waiting for triaging) label May 31, 2023
madtoinou added the new model (Adding a new model) label and removed the triage (Issue waiting for triaging) label Jun 1, 2023
@alexcolpitts96
Contributor Author

Google Research implementation: https://github.com/google-research/google-research/blob/master/tsmixer/tsmixer_basic/models/tsmixer.py

The paper is light on details; however, the source code clears things up.
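For reference, the basic model in that implementation stacks residual blocks that alternate a time-mixing MLP (a dense layer applied along the time axis, shared across features) with a feature-mixing MLP (applied along the feature axis, shared across time steps), followed by a linear projection from the input length to the forecast horizon. Below is a minimal NumPy sketch of that structure, written from a reading of the linked code rather than copied from it: the weights are random, there is no training, dropout is left out, and the normalization is a plain layer norm without learnable parameters. All function names are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each sample over its (time, feature) axes, without learnable scale/shift.
    mean = x.mean(axis=(-2, -1), keepdims=True)
    var = x.var(axis=(-2, -1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def relu(x):
    return np.maximum(x, 0.0)


def mixer_block(x, w_time, b_time, w_ff1, b_ff1, w_ff2, b_ff2):
    """One residual block; x has shape [batch, time, features]."""
    # Time mixing: one dense layer along the time axis, shared by all features.
    h = layer_norm(x)
    h = np.swapaxes(h, 1, 2)            # [batch, features, time]
    h = relu(h @ w_time + b_time)       # w_time: [time, time]
    h = np.swapaxes(h, 1, 2)            # back to [batch, time, features]
    res = x + h
    # Feature mixing: a two-layer MLP along the feature axis, shared by all time steps.
    h = layer_norm(res)
    h = relu(h @ w_ff1 + b_ff1)         # [batch, time, ff_dim]
    h = h @ w_ff2 + b_ff2               # [batch, time, features]
    return res + h


def tsmixer_forward(x, blocks, w_out, b_out):
    # Apply the residual blocks, then project the time axis onto the horizon.
    for params in blocks:
        x = mixer_block(x, *params)
    x = np.swapaxes(x, 1, 2)            # [batch, features, time]
    x = x @ w_out + b_out               # w_out: [time, horizon]
    return np.swapaxes(x, 1, 2)         # [batch, horizon, features]


rng = np.random.default_rng(0)
batch, time, features, ff_dim, horizon, n_blocks = 2, 24, 3, 16, 6, 2
blocks = [
    (
        rng.normal(size=(time, time)) * 0.1, np.zeros(time),
        rng.normal(size=(features, ff_dim)) * 0.1, np.zeros(ff_dim),
        rng.normal(size=(ff_dim, features)) * 0.1, np.zeros(features),
    )
    for _ in range(n_blocks)
]
w_out, b_out = rng.normal(size=(time, horizon)) * 0.1, np.zeros(horizon)
y = tsmixer_forward(rng.normal(size=(batch, time, features)), blocks, w_out, b_out)
print(y.shape)  # (2, 6, 3)
```

The main difference from TiDE visible here is that TSMixer mixes along time and along features in separate MLPs instead of flattening the lookback window into one dense encoder.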

@joshua-xia

@alexcolpitts96 did you also implement the paper's tsmixer_extended variant? It seems to support past, static, and future covariates.
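As a rough illustration of what supporting all three covariate types involves: one way to feed past, future and static covariates into a mixer is to first align them on the forecast horizon, projecting the historical block along time to the horizon length, repeating the static features at every step, and concatenating everything along the feature axis. The NumPy sketch below shows only this simplified alignment step under those assumptions; it is not the paper's exact extended architecture, which adds further mixing layers, and `align_inputs` is a hypothetical name.

```python
import numpy as np


def align_inputs(past, future_cov, static, w_time):
    """Simplified alignment of past/future/static inputs (illustrative only).

    past:       [batch, input_len, n_past]   target + past covariates
    future_cov: [batch, horizon, n_future]   known future covariates
    static:     [batch, n_static]            static covariates
    w_time:     [input_len, horizon]         temporal projection weights
    """
    horizon = future_cov.shape[1]
    # Project the historical block along the time axis onto the forecast horizon.
    past_proj = np.swapaxes(np.swapaxes(past, 1, 2) @ w_time, 1, 2)  # [batch, horizon, n_past]
    # Repeat the static features at every horizon step.
    static_rep = np.repeat(static[:, None, :], horizon, axis=1)      # [batch, horizon, n_static]
    # Concatenate along the feature axis; mixer blocks would then operate on this tensor.
    return np.concatenate([past_proj, future_cov, static_rep], axis=-1)


rng = np.random.default_rng(0)
aligned = align_inputs(
    past=rng.normal(size=(4, 48, 2)),
    future_cov=rng.normal(size=(4, 12, 3)),
    static=rng.normal(size=(4, 1)),
    w_time=rng.normal(size=(48, 12)) * 0.1,
)
print(aligned.shape)  # (4, 12, 6)
```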

@alexcolpitts96
Contributor Author

I have run into a few issues with the implementation and had some other PRs that I needed to clean up.

I managed to implement reversible instance normalization, but there is a bug in the tests that only shows up during the build process on GitHub.

The rest of the model is pretty straightforward; I just need to find the time. I just started a new job, so I have been a little short on time lately.
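Reversible instance normalization (RevIN) normalizes each input series with its own mean and standard deviation over the lookback window and applies the inverse transform to the model output, so the network works on scale-free inputs while forecasts come back in the original units. A minimal NumPy sketch, assuming [batch, time, features] arrays and fixed (rather than learned) affine parameters; the `RevIN` class here is illustrative and not the code from the fork:

```python
import numpy as np


class RevIN:
    """Reversible instance normalization over the time axis of [batch, time, features] arrays."""

    def __init__(self, n_features, eps=1e-5):
        self.eps = eps
        # Per-feature affine parameters; learnable in a real model, fixed in this sketch.
        self.gamma = np.ones(n_features)
        self.beta = np.zeros(n_features)

    def normalize(self, x):
        # Statistics are computed per sample and per feature, then stored for the inverse step.
        self.mean = x.mean(axis=1, keepdims=True)
        self.std = np.sqrt(x.var(axis=1, keepdims=True) + self.eps)
        return (x - self.mean) / self.std * self.gamma + self.beta

    def denormalize(self, y):
        # Undo the affine transform, then restore each instance's own level and scale.
        # The stored statistics broadcast over any time length, e.g. the forecast horizon.
        y = (y - self.beta) / self.gamma
        return y * self.std + self.mean


rng = np.random.default_rng(0)
x = rng.normal(loc=50.0, scale=10.0, size=(2, 24, 3))
revin = RevIN(n_features=3)
z = revin.normalize(x)
print(np.allclose(z.mean(axis=1), 0.0))      # True
print(np.allclose(revin.denormalize(z), x))  # True
```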

@meteoDaniel

meteoDaniel commented Sep 15, 2023

Recently Google published a paper and an article on TSMixer:
https://blog.research.google/2023/09/tsmixer-all-mlp-architecture-for-time.html

@alexcolpitts96 have you started on a PyTorch implementation that could fit into Darts?

@alexcolpitts96
Contributor Author

I started working on it roughly two months ago. I have been busy wrapping up school and starting a new job. I should have some time to clean it up over the next few weeks.

I managed to get the skeleton written, but I still need to add covariates and probabilistic forecasting.

https://github.com/alexcolpitts96/darts/blob/tsmixer/darts/models/forecasting/tsmixer_model.py

@meteoDaniel

From my point of view that looks good.
Why do you think you need probabilistic forecasting? Does TSMixer provide it by nature?
In TFT, probabilistic forecasting is a result of the quantile loss function. Maybe I am wrong, but if you want to add this feature to TSMixer, I think you just need to train it with QuantileLoss.
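For reference, the quantile (pinball) loss weights under-prediction by q and over-prediction by 1 - q, so a model trained on several quantile levels learns to output several quantiles of the forecast distribution. A minimal NumPy sketch (the function names are hypothetical). In Darts, as far as I know, this is usually configured through the model's `likelihood` argument, e.g. `likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9])` from `darts.utils.likelihood_models`, rather than through a loss named `QuantileLoss`.

```python
import numpy as np


def pinball_loss(y_true, y_pred, q):
    """Quantile (pinball) loss for a single quantile level q in (0, 1)."""
    err = y_true - y_pred
    # Under-prediction (err > 0) is weighted by q, over-prediction by (1 - q).
    return np.mean(np.maximum(q * err, (q - 1) * err))


def multi_quantile_loss(y_true, y_pred, quantiles):
    """y_pred has one trailing column per quantile level."""
    return sum(pinball_loss(y_true, y_pred[..., i], q) for i, q in enumerate(quantiles))


y = np.array([10.0])
print(pinball_loss(y, np.array([8.0]), 0.75))   # 1.5
print(pinball_loss(y, np.array([12.0]), 0.75))  # 0.5
```

The asymmetry is what pushes the 0.75 prediction above the median: missing low costs three times as much as missing high.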

@thijsjls

@alexcolpitts96 Have you had any time to work on this further? I would be interested in using this model, and I'm also open to contributing.

@StatMixedML

IBM has released its PatchTSMixer model on Hugging Face. Maybe this helps to get it into Darts soon.

@candalfigomoro

Note that there are apparently two different models named "TSMixer":

@leoniewgnr

leoniewgnr commented Mar 6, 2024

@alexcolpitts96 @meteoDaniel @thijsjls
Hi everyone, I've looked into your code, @alexcolpitts96, and it looks really good! I've tried it with the following code, using lists of time series, covariates, and encoders:

```python
from darts.dataprocessing.transformers import Scaler
from darts.models import TSMixerModel

# ts_train_scaled_list, ts_val_scaled_list, cov_list and n_epochs are defined elsewhere
model_params = {
    "input_chunk_length": 240,  # hist_len
    # not tuned
    "use_static_covariates": False,
    "output_chunk_length": 37,  # pred_len
    "n_epochs": n_epochs,
}

model = TSMixerModel(
    **model_params,
    pl_trainer_kwargs={
        "accelerator": "auto",
        "devices": "auto",
    },
    add_encoders={
        # calendar attributes as past and future covariates
        "datetime_attribute": {
            "past": ["hour", "day_of_week", "month"],
            "future": ["hour", "day_of_week", "month"],
        },
        # scale the encoded covariates
        "transformer": Scaler(),
    },
    model_name="tsmixer",
    save_checkpoints=True,
    force_reset=True,
)

model.fit(
    ts_train_scaled_list,
    future_covariates=cov_list,
    val_series=ts_val_scaled_list,
    val_future_covariates=cov_list,
    verbose=False,
)

# load the best model on the validation set to avoid overfitting
model = TSMixerModel.load_from_checkpoint(model_name="tsmixer", best=True)
```

and it works great! The only thing I had to change in your code was an old import statement from scikit-learn that is no longer in the current Darts version, so merging with the newest Darts version should resolve it.

I would really appreciate it if you could go ahead and push this, as I would really like to use it and the results from TSMixer are very good. Thank you very much!
I'm also very happy to help!
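With this configuration each training sample pairs 240 past steps with the 37 steps that follow them, so a series needs at least 240 + 37 = 277 points to produce even a single window. A simplified sketch of that windowing with stride 1 for one univariate series; `make_windows` is a hypothetical helper, and Darts' own training datasets also slice the covariates and may sample windows differently:

```python
import numpy as np


def make_windows(values, input_chunk_length, output_chunk_length):
    """Slice a 1-D series into (past, future) training pairs with stride 1."""
    total = input_chunk_length + output_chunk_length
    n_windows = len(values) - total + 1
    if n_windows < 1:
        raise ValueError("series is shorter than input_chunk_length + output_chunk_length")
    past = np.stack([values[i : i + input_chunk_length] for i in range(n_windows)])
    future = np.stack([values[i + input_chunk_length : i + total] for i in range(n_windows)])
    return past, future


series = np.arange(300, dtype=float)  # e.g. 300 hourly observations
past, future = make_windows(series, input_chunk_length=240, output_chunk_length=37)
print(past.shape, future.shape)  # (24, 240) (24, 37)
```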

@cristof-r
Contributor

I made a PR as the above seems to have gone stale. Any feedback is welcome!

@madtoinou
Collaborator

Fixed by #2293


9 participants