From cf57f7788b0903910898f8cbfb980e2734fda374 Mon Sep 17 00:00:00 2001 From: Dennis Bader Date: Fri, 11 Aug 2023 09:28:36 +0200 Subject: [PATCH] move tensor to cpu before numpy (#1949) * move tensor to cpu before numpy * update changelog --- CHANGELOG.md | 1 + darts/explainability/tft_explainer.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6e8cffc94..db1802937b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Fixed** - Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou). +- Fixed a bug when using `TFTExplainer` with a `TFTModel` running on GPU. [#1949](https://github.com/unit8co/darts/pull/1949) by [Dennis Bader](https://github.com/dennisbader). ## [0.25.0](https://github.com/unit8co/darts/tree/0.25.0) (2023-08-04) diff --git a/darts/explainability/tft_explainer.py b/darts/explainability/tft_explainer.py index d44633d62b..4665638a49 100644 --- a/darts/explainability/tft_explainer.py +++ b/darts/explainability/tft_explainer.py @@ -207,7 +207,7 @@ def explain( # get the weights and the attention head from the trained model for the prediction # aggregate over attention heads attention_heads = ( - self.model.model._attn_out_weights.detach().numpy().sum(axis=-2) + self.model.model._attn_out_weights.detach().cpu().numpy().sum(axis=-2) ) # get the variable importances (pd.DataFrame with rows corresponding to the number of input series) encoder_importance = self._encoder_importance @@ -501,7 +501,8 @@ def _get_importance( # transform the encoder/decoder weights to percentages, rounded to n_decimals weights_percentage = ( - weight.detach().numpy().mean(axis=1).squeeze(axis=1).round(n_decimals) * 100 + weight.detach().cpu().numpy().mean(axis=1).squeeze(axis=1).round(n_decimals) + * 100 ) # create a dataframe with the variable names and the weights name_mapping = self._name_mapping