From ba61337bd0e5410c04cc708be57affc191a8c424 Mon Sep 17 00:00:00 2001
From: Xun Zheng
Date: Wed, 16 Dec 2020 21:29:35 -0500
Subject: [PATCH] Update nonlinear.py

---
Review notes (text in this section is commentary, not part of the commit):
- Both h_func hunks switch the acyclicity term from the matrix-power form
  (Yu et al. 2019) back to trace_expm(A) - d (Zheng et al. 2018); the old
  form is kept as a comment.
- In the second h_func hunk the added line now subtracts self.d instead of
  a bare d. Every context and replaced line visible in that hunk uses
  self.d and no local d is visible, so the bare name would presumably
  raise NameError -- confirm against the full method.
- Because of that edit the post-image blob id on the index line (cc6d15a)
  no longer matches the patched file exactly.

 notears/nonlinear.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/notears/nonlinear.py b/notears/nonlinear.py
index 188b439..cc6d15a 100644
--- a/notears/nonlinear.py
+++ b/notears/nonlinear.py
@@ -1,5 +1,6 @@
 from notears.locally_connected import LocallyConnected
 from notears.lbfgsb_scipy import LBFGSBScipy
+from notears.trace_expm import trace_expm
 import torch
 import torch.nn as nn
 import numpy as np
@@ -52,10 +53,11 @@ def h_func(self):
         fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
         fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
         A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
-        # h = trace_expm(A) - d  # (Zheng et al. 2018)
-        M = torch.eye(d) + A / d  # (Yu et al. 2019)
-        E = torch.matrix_power(M, d - 1)
-        h = (E.t() * M).sum() - d
+        h = trace_expm(A) - d  # (Zheng et al. 2018)
+        # A different formulation, slightly faster at the cost of numerical stability
+        # M = torch.eye(d) + A / d  # (Yu et al. 2019)
+        # E = torch.matrix_power(M, d - 1)
+        # h = (E.t() * M).sum() - d
         return h
 
     def l2_reg(self):
@@ -130,10 +132,11 @@ def h_func(self):
         fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j, ik]
         fc1_weight = fc1_weight.view(self.d, self.d, self.k)  # [j, i, k]
         A = torch.sum(fc1_weight * fc1_weight, dim=2).t()  # [i, j]
-        # h = trace_expm(A) - d  # (Zheng et al. 2018)
-        M = torch.eye(self.d) + A / self.d  # (Yu et al. 2019)
-        E = torch.matrix_power(M, self.d - 1)
-        h = (E.t() * M).sum() - self.d
+        h = trace_expm(A) - self.d  # (Zheng et al. 2018)
+        # A different formulation, slightly faster at the cost of numerical stability
+        # M = torch.eye(self.d) + A / self.d  # (Yu et al. 2019)
+        # E = torch.matrix_power(M, self.d - 1)
+        # h = (E.t() * M).sum() - self.d
         return h
 
     def l2_reg(self):