From 3f07706a6b21e46c3106f52d87298c8f3f357d54 Mon Sep 17 00:00:00 2001
From: Trang Thoi <103948755+tdktrang@users.noreply.github.com>
Date: Sat, 18 May 2024 18:08:56 +0900
Subject: [PATCH 1/3] Update inference.py

Related to DeepLearningExamples/PyTorch/Forecasting/TFT/
(e.g. GNMT/PyTorch or FasterTransformer/All)

Describe the bug
When I run the "inference.py" the error happen because "unscaled_predictions" was numpy.ndarray. Therefore, we need to add the code to process the unscaled_predictions to tensor

To Reproduce
Steps to reproduce the behavior:

python inference.py \
--checkpoint /results/TFT_electricity_bs8x1024_lr1e-3/seed_1/checkpoint.pt \
--data /data/processed/electricity_bin/test.csv \
--tgt_scalers /data/processed/electricity_bin/tgt_scalers.bin \
--cat_encodings /data/processed/electricity_bin/cat_encodings.bin \
--visualize \
--save_predictions
Expected behavior
'numpy.ndarray' object has no attribute 'new_full'

Environment

GPUs in the system: NVIDIA GeForce RTX 3090
CUDA driver version 520.61.05
---
 PyTorch/Forecasting/TFT/inference.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/PyTorch/Forecasting/TFT/inference.py b/PyTorch/Forecasting/TFT/inference.py
index 7f60f5588..a4ed0daff 100644
--- a/PyTorch/Forecasting/TFT/inference.py
+++ b/PyTorch/Forecasting/TFT/inference.py
@@ -111,7 +111,8 @@ def predict(args, config, model, data_loader, scalers, cat_encodings, extend_tar
 
 def visualize_v2(args, config, model, data_loader, scalers, cat_encodings):
     unscaled_predictions, unscaled_targets, ids, _ = predict(args, config, model, data_loader, scalers, cat_encodings, extend_targets=True)
-
+    unscaled_predictions = torch.tensor(unscaled_predictions)
+    unscaled_targets = torch.tensor(unscaled_targets)
     num_horizons = config.example_length - config.encoder_length + 1
     pad = unscaled_predictions.new_full((unscaled_targets.shape[0], unscaled_targets.shape[1] - unscaled_predictions.shape[1], unscaled_predictions.shape[2]), fill_value=float('nan'))
     pad[:,-1,:] = unscaled_targets[:,-num_horizons,:]

From 879d294d686c45e49efcd7bb8dfa35492202ef69 Mon Sep 17 00:00:00 2001
From: Trang Thoi <103948755+tdktrang@users.noreply.github.com>
Date: Tue, 21 May 2024 09:06:51 +0900
Subject: [PATCH 2/3] Update inference.py

---
 PyTorch/Forecasting/TFT/inference.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/PyTorch/Forecasting/TFT/inference.py b/PyTorch/Forecasting/TFT/inference.py
index a4ed0daff..cbbffe4e4 100644
--- a/PyTorch/Forecasting/TFT/inference.py
+++ b/PyTorch/Forecasting/TFT/inference.py
@@ -139,6 +139,8 @@ def inference(args, config, model, data_loader, scalers, cat_encodings):
     if args.joint_visualization or args.save_predictions:
         ids = torch.from_numpy(ids.squeeze())
         #ids = torch.cat([x['id'][0] for x in data_loader.dataset])
+        unscaled_predictions = torch.tensor(unscaled_predictions)
+        unscaled_targets = torch.tensor(unscaled_targets)
         joint_graphs = torch.cat([unscaled_targets, unscaled_predictions], dim=2)
         graphs = {i:joint_graphs[ids == i, :, :] for i in set(ids.tolist())}
         for key, g in graphs.items(): #timeseries id, joint targets and predictions

From 5590978575e0f1de6c051a56785eda4ed4671050 Mon Sep 17 00:00:00 2001
From: Trang Thoi <103948755+tdktrang@users.noreply.github.com>
Date: Tue, 21 May 2024 09:16:09 +0900
Subject: [PATCH 3/3] Update criterions.py

---
 PyTorch/Forecasting/TFT/criterions.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/PyTorch/Forecasting/TFT/criterions.py b/PyTorch/Forecasting/TFT/criterions.py
index 12de5be76..566467a23 100644
--- a/PyTorch/Forecasting/TFT/criterions.py
+++ b/PyTorch/Forecasting/TFT/criterions.py
@@ -29,6 +29,10 @@ def forward(self, predictions, targets):
         return losses
 
 def qrisk(pred, tgt, quantiles):
+    if isinstance(pred, torch.Tensor):
+        pred = pred.detach().cpu().numpy()
+    if isinstance(tgt, torch.Tensor):
+        tgt = tgt.detach().cpu().numpy()
     diff = pred - tgt
     ql = (1-quantiles)*np.clip(diff,0, float('inf')) + quantiles*np.clip(-diff,0, float('inf'))
     losses = ql.reshape(-1, ql.shape[-1])