From 5ab9bcf72531362413455945380f798106998f89 Mon Sep 17 00:00:00 2001
From: "siqiao.xsq"
Date: Wed, 8 May 2024 16:21:41 +0800
Subject: [PATCH] fix bugs in parameter initialization in RMTPP

---
 easy_tpp/model/torch_model/torch_rmtpp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/easy_tpp/model/torch_model/torch_rmtpp.py b/easy_tpp/model/torch_model/torch_rmtpp.py
index 0e9fc14..da327f5 100644
--- a/easy_tpp/model/torch_model/torch_rmtpp.py
+++ b/easy_tpp/model/torch_model/torch_rmtpp.py
@@ -24,8 +24,8 @@ def __init__(self, model_config):
 
         self.layer_hidden = nn.Linear(self.hidden_size, self.num_event_types)
 
-        self.factor_intensity_base = torch.empty([1, 1, self.num_event_types], device=self.device)
-        self.factor_intensity_current_influence = torch.empty([1, 1, self.num_event_types], device=self.device)
+        self.factor_intensity_base = torch.nn.Parameter(torch.empty([1, 1, self.num_event_types], device=self.device))
+        self.factor_intensity_current_influence = torch.nn.Parameter(torch.empty([1, 1, self.num_event_types], device=self.device))
         nn.init.xavier_normal_(self.factor_intensity_base)
         nn.init.xavier_normal_(self.factor_intensity_current_influence)
 
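
Note (not part of the patch): the diff wraps the two intensity factors in torch.nn.Parameter. Without this, the tensors created by torch.empty are plain attributes of the module: they do not appear in model.parameters(), are never updated by the optimizer, and are not moved by .to(device) or saved in the state_dict. The minimal sketch below, using a hypothetical Toy module that is not part of the repository code, illustrates the difference.

# Minimal sketch (illustration only, not repository code).
import torch
import torch.nn as nn

class Toy(nn.Module):  # hypothetical module for demonstration
    def __init__(self, num_event_types=3):
        super().__init__()
        # Before the fix: a bare tensor assigned to an attribute is invisible
        # to PyTorch's module machinery.
        self.bad = torch.empty([1, 1, num_event_types])
        # After the fix: registered as a learnable parameter.
        self.good = nn.Parameter(torch.empty([1, 1, num_event_types]))
        nn.init.xavier_normal_(self.good)

m = Toy()
print([name for name, _ in m.named_parameters()])  # ['good'] -- 'bad' is missing
print(list(m.state_dict().keys()))                 # ['good'] -- 'bad' is not saved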