Skip to content

Commit

Permalink
fix device error of graph, di_graph, bi_graph, hypergraph laplacian computing
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanfeng97 committed Jan 3, 2023
1 parent 1ce68d4 commit 8346349
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions dhg/structure/graphs/bipartite_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def D_u(self) -> torch.Tensor:
if self.cache.get("D_u", None) is None:
_tmp = torch.sparse.sum(self.B, dim=1).to_dense().clone().view(-1)
self.cache["D_u"] = torch.sparse_coo_tensor(
indices=torch.arange(0, self.num_u).view(1, -1).repeat(2, 1),
indices=torch.arange(0, self.num_u, device=self.device).view(1, -1).repeat(2, 1),
values=_tmp,
size=torch.Size([self.num_u, self.num_u]),
device=self.device,
Expand All @@ -523,7 +523,7 @@ def D_v(self) -> torch.Tensor:
if self.cache.get("D_v", None) is None:
_tmp = torch.sparse.sum(self.B_T, dim=1).to_dense().clone().view(-1)
self.cache["D_v"] = torch.sparse_coo_tensor(
indices=torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
values=_tmp,
size=torch.Size([self.num_v, self.num_v]),
device=self.device,
Expand Down
4 changes: 2 additions & 2 deletions dhg/structure/graphs/directed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def D_v_in(self) -> torch.Tensor:
if self.cache.get("D_v_in", None) is None:
_tmp = torch.sparse.sum(self.A_T, dim=1).to_dense().clone().view(-1)
self.cache["D_v_in"] = torch.sparse_coo_tensor(
indices=torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
values=_tmp,
size=torch.Size([self.num_v, self.num_v]),
device=self.device,
Expand All @@ -457,7 +457,7 @@ def D_v_out(self) -> torch.Tensor:
if self.cache.get("D_v_out", None) is None:
_tmp = torch.sparse.sum(self.A, dim=1).to_dense().clone().view(-1)
self.cache["D_v_out"] = torch.sparse_coo_tensor(
indices=torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
values=_tmp,
size=torch.Size([self.num_v, self.num_v]),
device=self.device,
Expand Down
8 changes: 4 additions & 4 deletions dhg/structure/graphs/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ def L_sym(self) -> torch.Tensor:
_tmp_g.remove_selfloop()
_L = _tmp_g.D_v_neg_1_2.mm(_tmp_g.A).mm(_tmp_g.D_v_neg_1_2).clone()
self.cache["L_sym"] = torch.sparse_coo_tensor(
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _L._indices(),]),
torch.hstack([torch.ones(self.num_v), -_L._values()]),
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _L._indices(),]),
torch.hstack([torch.ones(self.num_v, device=self.device), -_L._values()]),
torch.Size([self.num_v, self.num_v]),
device=self.device,
).coalesce()
Expand All @@ -611,8 +611,8 @@ def L_rw(self) -> torch.Tensor:
_tmp_g.remove_selfloop()
_L = _tmp_g.D_v_neg_1.mm(_tmp_g.A).clone()
self.cache["L_rw"] = torch.sparse_coo_tensor(
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _L._indices(),]),
torch.hstack([torch.ones(self.num_v), -_L._values()]),
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _L._indices(),]),
torch.hstack([torch.ones(self.num_v, device=self.device), -_L._values()]),
torch.Size([self.num_v, self.num_v]),
device=self.device,
).coalesce()
Expand Down
30 changes: 15 additions & 15 deletions dhg/structure/hypergraphs/hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def W_v(self) -> torch.Tensor:
_tmp = torch.Tensor(self.v_weight)
_num_v = _tmp.size(0)
self.cache["W_v"] = torch.sparse_coo_tensor(
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
torch.arange(0, _num_v, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_v, _num_v]),
device=self.device,
Expand All @@ -830,7 +830,7 @@ def W_e(self) -> torch.Tensor:
_tmp = torch.cat(_tmp, dim=0).view(-1)
_num_e = _tmp.size(0)
self.cache["W_e"] = torch.sparse_coo_tensor(
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_e, _num_e]),
device=self.device,
Expand All @@ -848,7 +848,7 @@ def W_e_of_group(self, group_name: str) -> torch.Tensor:
_tmp = self._fetch_W_of_group(group_name).view(-1)
_num_e = _tmp.size(0)
self.group_cache[group_name]["W_e"] = torch.sparse_coo_tensor(
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_e, _num_e]),
device=self.device,
Expand All @@ -863,7 +863,7 @@ def D_v(self) -> torch.Tensor:
_tmp = [self.D_v_of_group(name)._values().clone() for name in self.group_names]
_tmp = torch.vstack(_tmp).sum(dim=0).view(-1)
self.cache["D_v"] = torch.sparse_coo_tensor(
torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([self.num_v, self.num_v]),
device=self.device,
Expand All @@ -885,7 +885,7 @@ def D_v_of_group(self, group_name: str) -> torch.Tensor:
_tmp = torch.sparse.sum(H_, dim=1).to_dense().clone().view(-1)
_num_v = _tmp.size(0)
self.group_cache[group_name]["D_v"] = torch.sparse_coo_tensor(
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
torch.arange(0, _num_v, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_v, _num_v]),
device=self.device,
Expand Down Expand Up @@ -959,7 +959,7 @@ def D_e(self) -> torch.Tensor:
_tmp = torch.cat(_tmp, dim=0).view(-1)
_num_e = _tmp.size(0)
self.cache["D_e"] = torch.sparse_coo_tensor(
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_e, _num_e]),
device=self.device,
Expand All @@ -977,7 +977,7 @@ def D_e_of_group(self, group_name: str) -> torch.Tensor:
_tmp = torch.sparse.sum(self.H_T_of_group(group_name), dim=1).to_dense().clone().view(-1)
_num_e = _tmp.size(0)
self.group_cache[group_name]["D_e"] = torch.sparse_coo_tensor(
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_e, _num_e]),
device=self.device,
Expand Down Expand Up @@ -1090,8 +1090,8 @@ def L_sym(self) -> torch.Tensor:
if self.cache.get("L_sym") is None:
L_HGNN = self.L_HGNN.clone()
self.cache["L_sym"] = torch.sparse_coo_tensor(
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
torch.hstack([torch.ones(self.num_v), -L_HGNN._values()]),
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
torch.hstack([torch.ones(self.num_v, device=self.device), -L_HGNN._values()]),
torch.Size([self.num_v, self.num_v]),
device=self.device,
).coalesce()
Expand All @@ -1110,8 +1110,8 @@ def L_sym_of_group(self, group_name: str) -> torch.Tensor:
if self.group_cache[group_name].get("L_sym") is None:
L_HGNN = self.L_HGNN_of_group(group_name).clone()
self.group_cache[group_name]["L_sym"] = torch.sparse_coo_tensor(
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
torch.hstack([torch.ones(self.num_v), -L_HGNN._values()]),
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
torch.hstack([torch.ones(self.num_v, device=self.device), -L_HGNN._values()]),
torch.Size([self.num_v, self.num_v]),
device=self.device,
).coalesce()
Expand All @@ -1128,8 +1128,8 @@ def L_rw(self) -> torch.Tensor:
_tmp = self.D_v_neg_1.mm(self.H).mm(self.W_e).mm(self.D_e_neg_1).mm(self.H_T)
self.cache["L_rw"] = (
torch.sparse_coo_tensor(
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _tmp._indices(),]),
torch.hstack([torch.ones(self.num_v), -_tmp._values()]),
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _tmp._indices(),]),
torch.hstack([torch.ones(self.num_v, device=self.device), -_tmp._values()]),
torch.Size([self.num_v, self.num_v]),
device=self.device,
)
Expand Down Expand Up @@ -1158,8 +1158,8 @@ def L_rw_of_group(self, group_name: str) -> torch.Tensor:
)
self.group_cache[group_name]["L_rw"] = (
torch.sparse_coo_tensor(
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _tmp._indices(),]),
torch.hstack([torch.ones(self.num_v), -_tmp._values()]),
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _tmp._indices(),]),
torch.hstack([torch.ones(self.num_v, device=self.device), -_tmp._values()]),
torch.Size([self.num_v, self.num_v]),
device=self.device,
)
Expand Down

0 comments on commit 8346349

Please sign in to comment.