Skip to content

Commit

Permalink
Match escn_exportable with escn main (#866)
Browse files Browse the repository at this point in the history
* fix l/m; make rescaling optional

* remove mapping reduced from tests, hoping we wont need to register buffer
  • Loading branch information
misko authored Sep 20, 2024
1 parent 6bd1888 commit 83fd9d2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 28 deletions.
22 changes: 8 additions & 14 deletions src/fairchem/core/models/escn/escn_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
resolution: int | None = None,
compile: bool = False,
export: bool = False,
rescale_grid: bool = False,
) -> None:
super().__init__()

Expand All @@ -103,6 +104,7 @@ def __init__(
self.distance_function = distance_function
self.compile = compile
self.export = export
self.rescale_grid = rescale_grid

# non-linear activation function used throughout the network
self.act = nn.SiLU()
Expand Down Expand Up @@ -152,12 +154,11 @@ def __init__(
# Initialize the transformations between spherical and grid representations
self.SO3_grid = nn.ModuleDict()
self.SO3_grid["lmax_lmax"] = SO3_Grid(
self.lmax, self.lmax, resolution=resolution
self.lmax, self.lmax, resolution=resolution, rescale=self.rescale_grid
)
self.SO3_grid["lmax_mmax"] = SO3_Grid(
self.lmax, self.mmax, resolution=resolution
self.lmax, self.mmax, resolution=resolution, rescale=self.rescale_grid
)
self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax])

# Initialize the blocks for each layer of the GNN
self.layer_blocks = nn.ModuleList()
Expand All @@ -173,7 +174,6 @@ def __init__(
self.max_num_elements,
self.SO3_grid,
self.act,
self.mappingReduced,
)
self.layer_blocks.append(block)

Expand Down Expand Up @@ -435,7 +435,6 @@ def __init__(
max_num_elements: int,
SO3_grid: SO3_Grid,
act,
mappingReduced,
) -> None:
super().__init__()
self.layer_idx = layer_idx
Expand All @@ -444,7 +443,6 @@ def __init__(
self.mmax = mmax
self.sphere_channels = sphere_channels
self.SO3_grid = SO3_grid
self.mappingReduced = mappingReduced

# Message block
self.message_block = MessageBlock(
Expand All @@ -458,7 +456,6 @@ def __init__(
max_num_elements,
self.SO3_grid,
self.act,
self.mappingReduced,
)

# Non-linear point-wise comvolution for the aggregated messages
Expand Down Expand Up @@ -547,7 +544,6 @@ def __init__(
max_num_elements: int,
SO3_grid: SO3_Grid,
act,
mappingReduced,
) -> None:
super().__init__()
self.layer_idx = layer_idx
Expand All @@ -558,8 +554,9 @@ def __init__(
self.lmax = lmax
self.mmax = mmax
self.edge_channels = edge_channels
self.mappingReduced = mappingReduced
self.out_mask = self.mappingReduced.coefficient_idx(self.lmax, self.mmax)
self.out_mask = CoefficientMapping([self.lmax], [self.lmax]).coefficient_idx(
self.lmax, self.mmax
)

# Create edge scalar (invariant to rotations) features
self.edge_block = EdgeBlock(
Expand All @@ -577,7 +574,6 @@ def __init__(
self.lmax,
self.mmax,
self.act,
self.mappingReduced,
)
self.so2_block_target = SO2Block(
self.sphere_channels,
Expand All @@ -586,7 +582,6 @@ def __init__(
self.lmax,
self.mmax,
self.act,
self.mappingReduced,
)

def forward(
Expand Down Expand Up @@ -666,15 +661,14 @@ def __init__(
lmax: int,
mmax: int,
act,
mappingReduced,
) -> None:
super().__init__()
self.sphere_channels = sphere_channels
self.hidden_channels = hidden_channels
self.lmax = lmax
self.mmax = mmax
self.act = act
self.mappingReduced = mappingReduced
self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax])

num_channels_m0 = (self.lmax + 1) * self.sphere_channels

Expand Down
1 change: 0 additions & 1 deletion src/fairchem/core/models/escn/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def _grid_act(self, SO3_grid, act, mappingReduced) -> None:
from_grid_mat = SO3_grid[self.lmax_list[i]][
self.mmax_list[i]
].get_from_grid_mat(self.device)

x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res)
x_grid = act(x_grid)
x_res = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)
Expand Down
5 changes: 3 additions & 2 deletions src/fairchem/core/models/escn/so3_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(
mmax: int,
normalization: str = "integral",
resolution: int | None = None,
rescale: bool = False,
):
super().__init__()

Expand All @@ -276,7 +277,7 @@ def __init__(
)
to_grid_mat = torch.einsum("mbi, am -> bai", to_grid.shb, to_grid.sha).detach()
# rescale based on mmax
if lmax != mmax:
if rescale and lmax != mmax:
for lval in range(lmax + 1):
if lval <= mmax:
continue
Expand All @@ -300,7 +301,7 @@ def __init__(
"am, mbi -> bai", from_grid.sha, from_grid.shb
).detach()
# rescale based on mmax
if lmax != mmax:
if rescale and lmax != mmax:
for lval in range(lmax + 1):
if lval <= mmax:
continue
Expand Down
15 changes: 4 additions & 11 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from fairchem.core.datasets import data_list_collater
from fairchem.core.models.escn import escn_exportable
from fairchem.core.models.escn.so3_exportable import (
CoefficientMapping,
SO3_Grid,
)
from fairchem.core.models.scn.smearing import GaussianSmearing
Expand Down Expand Up @@ -70,8 +69,8 @@ def load_model(type: str, compile=False, export=False):
cutoff=CUTOFF,
max_num_elements=MAX_ELEMENTS,
num_layers=8,
lmax_list=[4],
mmax_list=[2],
lmax_list=[6],
mmax_list=[0],
sphere_channels=128,
hidden_channels=256,
edge_channels=128,
Expand All @@ -87,8 +86,8 @@ def load_model(type: str, compile=False, export=False):
cutoff=CUTOFF,
max_num_elements=MAX_ELEMENTS,
num_layers=8,
lmax=4,
mmax=2,
lmax=6,
mmax=0,
sphere_channels=128,
hidden_channels=256,
edge_channels=128,
Expand Down Expand Up @@ -210,7 +209,6 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None:
}

lmax, mmax = 4, 2
mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax])
shpere_channels = 128
edge_channels = 128
args = (torch.rand(680, 19, shpere_channels), torch.rand(680, edge_channels))
Expand All @@ -222,7 +220,6 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None:
lmax=lmax,
mmax=mmax,
act=torch.nn.SiLU(),
mappingReduced=mappingReduced,
)
prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1)
export_out = prog.module()(*args)
Expand All @@ -244,7 +241,6 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
SO3_grid = torch.nn.ModuleDict()
SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
mappingReduced = CoefficientMapping([lmax], [mmax])
message_block = escn_exportable.MessageBlock(
layer_idx=0,
sphere_channels=sphere_channels,
Expand All @@ -256,7 +252,6 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
max_num_elements=90,
SO3_grid=SO3_grid,
act=torch.nn.SiLU(),
mappingReduced=mappingReduced,
)

# generate inputs
Expand Down Expand Up @@ -297,7 +292,6 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None:
SO3_grid = torch.nn.ModuleDict()
SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
mappingReduced = CoefficientMapping([lmax], [mmax])
layer_block = escn_exportable.LayerBlock(
layer_idx=0,
sphere_channels=sphere_channels,
Expand All @@ -309,7 +303,6 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None:
max_num_elements=90,
SO3_grid=SO3_grid,
act=torch.nn.SiLU(),
mappingReduced=mappingReduced,
)

# generate inputs
Expand Down

0 comments on commit 83fd9d2

Please sign in to comment.