@@ -372,7 +372,7 @@ def _cal_hg_dynamic(
372372 # n_edge x e_dim
373373 flat_edge_ebd = flat_edge_ebd * flat_sw .unsqueeze (- 1 )
374374 # n_edge x 3 x e_dim
375- flat_h2g2 = (flat_h2 [..., None ] * flat_edge_ebd [:, None , :] ).reshape (
375+ flat_h2g2 = (flat_h2 . unsqueeze ( - 1 ) * flat_edge_ebd . unsqueeze ( - 2 ) ).reshape (
376376 [- 1 , 3 * e_dim ]
377377 )
378378 # nf x nloc x 3 x e_dim
@@ -586,7 +586,9 @@ def optim_angle_update_dynamic(
586586 sub_node_update = paddle .matmul (node_ebd , sub_node )
587587 # n_angle * angle_dim
588588 sub_node_update = paddle .index_select (
589- sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), n2a_index , 0
589+ sub_node_update .reshape ([nf * nloc , sub_node_update .shape [- 1 ]]),
590+ n2a_index ,
591+ 0 ,
590592 )
591593
592594 # n_edge * angle_dim
@@ -666,7 +668,7 @@ def optim_edge_update_dynamic(
666668 sub_node_update = paddle .matmul (node_ebd , node )
667669 # n_edge * node/edge_dim
668670 sub_node_update = paddle .index_select (
669- sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]),
671+ sub_node_update .reshape ([ nf * nloc , sub_node_update .shape [- 1 ] ]),
670672 n2e_index ,
671673 0 ,
672674 )
@@ -675,7 +677,7 @@ def optim_edge_update_dynamic(
675677 sub_node_ext_update = paddle .matmul (node_ebd_ext , node_ext )
676678 # n_edge * node/edge_dim
677679 sub_node_ext_update = paddle .index_select (
678- sub_node_ext_update .reshape (nf * nall , sub_node_update .shape [- 1 ]),
680+ sub_node_ext_update .reshape ([ nf * nall , sub_node_update .shape [- 1 ] ]),
679681 n_ext2e_index ,
680682 0 ,
681683 )
@@ -698,8 +700,8 @@ def forward(
698700 a_nlist : paddle .Tensor , # nf x nloc x a_nnei
699701 a_nlist_mask : paddle .Tensor , # nf x nloc x a_nnei
700702 a_sw : paddle .Tensor , # switch func, nf x nloc x a_nnei
701- edge_index : paddle .Tensor , # n_edge x 2
702- angle_index : paddle .Tensor , # n_angle x 3
703+ edge_index : paddle .Tensor , # 2 x n_edge
704+ angle_index : paddle .Tensor , # 3 x n_angle
703705 ):
704706 """
705707 Parameters
@@ -724,12 +726,12 @@ def forward(
724726 Masks of the neighbor list for angle. real nei 1 otherwise 0
725727 a_sw : nf x nloc x a_nnei
726728 Switch function for angle.
727- edge_index : Optional for dynamic sel, n_edge x 2
729+ edge_index : Optional for dynamic sel, 2 x n_edge
728730 n2e_index : n_edge
729731 Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
730732 n_ext2e_index : n_edge
731733 Broadcast indices from extended node(j) to edge(ij).
732- angle_index : Optional for dynamic sel, n_angle x 3
734+ angle_index : Optional for dynamic sel, 3 x n_angle
733735 n2a_index : n_angle
734736 Broadcast indices from extended node(j) to angle(ijk).
735737 eij2a_index : n_angle
@@ -746,25 +748,24 @@ def forward(
746748 a_updated : nf x nloc x a_nnei x a_nnei x a_dim
747749 Updated angle embedding.
748750 """
749- nb , nloc , nnei , _ = edge_ebd .shape
751+ nb , nloc , nnei = nlist .shape
750752 nall = node_ebd_ext .shape [1 ]
751753 node_ebd = node_ebd_ext [:, :nloc , :]
752- n_edge = int (nlist_mask .sum ().item ())
753754 if paddle .in_dynamic_mode ():
754755 assert [nb , nloc ] == node_ebd .shape [:2 ]
755756 if not self .use_dynamic_sel :
756757 if paddle .in_dynamic_mode ():
757758 assert [nb , nloc , nnei , 3 ] == h2 .shape
759+ n_edge = None
758760 else :
759- if paddle .in_dynamic_mode ():
760- assert [n_edge , 3 ] == h2 .shape
761+ n_edge = h2 .shape [0 ]
761762 del a_nlist # may be used in the future
762763
763- n2e_index , n_ext2e_index = edge_index [:, 0 ], edge_index [:, 1 ]
764+ n2e_index , n_ext2e_index = edge_index [0 ], edge_index [1 ]
764765 n2a_index , eij2a_index , eik2a_index = (
765- angle_index [:, 0 ],
766- angle_index [:, 1 ],
767- angle_index [:, 2 ],
766+ angle_index [0 ],
767+ angle_index [1 ],
768+ angle_index [2 ],
768769 )
769770
770771 # nb x nloc x nnei x n_dim [OR] n_edge x n_dim
@@ -896,7 +897,7 @@ def forward(
896897 n2e_index ,
897898 average = False ,
898899 num_owner = nb * nloc ,
899- ).reshape (nb , nloc , node_edge_update .shape [- 1 ])
900+ ).reshape ([ nb , nloc , node_edge_update .shape [- 1 ] ])
900901 / self .dynamic_e_sel
901902 )
902903 )
@@ -1042,7 +1043,9 @@ def forward(
10421043 if not self .use_dynamic_sel :
10431044 # nb x nloc x a_nnei x a_nnei x e_dim
10441045 weighted_edge_angle_update = (
1045- a_sw [..., None , None ] * a_sw [..., None , :, None ] * edge_angle_update
1046+ a_sw .unsqueeze (- 1 ).unsqueeze (- 1 )
1047+ * a_sw .unsqueeze (- 2 ).unsqueeze (- 1 )
1048+ * edge_angle_update
10461049 )
10471050 # nb x nloc x a_nnei x e_dim
10481051 reduced_edge_angle_update = paddle .sum (
0 commit comments