【Hackathon 8th No.14】CoNFiLD paper reproduction #1110

Open · wants to merge 14 commits into base: develop
369 changes: 369 additions & 0 deletions docs/zh/examples/confild.md

Large diffs are not rendered by default.

Binary file added docs/zh/examples/confild.png
108 changes: 108 additions & 0 deletions examples/confild/conf/confild_case1.yaml
@@ -0,0 +1,108 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

hydra:
run:
# dynamic output directory according to running time and override name
# dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
dir: ./outputs_confild_case1
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: infer # running mode: infer
seed: 2025
output_dir: ${hydra:run.dir}
log_freq: 20

TRAIN:
batch_size: 64
test_batch_size: 256
epochs: 9800
mutil_GPU: 1
lr:
cnf: 1.e-4
latents: 1.e-5

EVAL:
confild_pretrained_model_path: ./outputs_confild_case1/confild_case1/epoch_99999
latent_pretrained_model_path: ./outputs_confild_case1/latent_case1/epoch_99999

CONFILD:
input_keys: ["confild_x", "latent_z"]
output_keys: ["confild_output"]
num_hidden_layers: 10
out_features: 3
hidden_features: 128
in_coord_features: 2
in_latent_features: 128

Latent:
input_keys: ["latent_x"]
output_keys: ["latent_z"]
N_samples: 16000
lumped: True
N_features: 128
dims: 2

INFER:
Latent:
INFER:
pretrained_model_path: null
export_path: ./inference/latent_case1
pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
onnx_path: ${INFER.Latent.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
log_freq: 20
Confild:
INFER:
pretrained_model_path: null
export_path: ./inference/confild_case1
pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
onnx_path: ${INFER.Confild.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
coord_shape: [918, 2]
latents_shape: [1, 128]
log_freq: 20
batch_size: 64

Data:
data_path: ../case1/data.npy
coor_path: ../case1/coor.npy
normalizer:
method: "-11"
dim: 0
load_data_fn: load_elbow_flow
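For orientation, a minimal sketch (not part of this PR's diff) of how the CONFILD and Latent sections above map onto the two architectures the PR registers in ppsci.arch. The build_models helper and the Hydra cfg access pattern are illustrative assumptions, not the PR's exact confild.py code.

import ppsci

def build_models(cfg):
    # decoder: coordinates + latent code -> flow field (keys mirror the CONFILD section)
    confild = ppsci.arch.SIRENAutodecoder_film(
        input_keys=cfg.CONFILD.input_keys,
        output_keys=cfg.CONFILD.output_keys,
        in_coord_features=cfg.CONFILD.in_coord_features,
        in_latent_features=cfg.CONFILD.in_latent_features,
        out_features=cfg.CONFILD.out_features,
        num_hidden_layers=cfg.CONFILD.num_hidden_layers,
        hidden_features=cfg.CONFILD.hidden_features,
    )
    # trainable per-snapshot latents (keys mirror the Latent section)
    latent = ppsci.arch.LatentContainer(
        input_keys=cfg.Latent.input_keys,
        output_keys=cfg.Latent.output_keys,
        N_samples=cfg.Latent.N_samples,
        N_features=cfg.Latent.N_features,
        dims=cfg.Latent.dims,
        lumped=cfg.Latent.lumped,
    )
    return confild, latent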
108 changes: 108 additions & 0 deletions examples/confild/conf/confild_case2.yaml
@@ -0,0 +1,108 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

hydra:
run:
# dynamic output directory according to running time and override name
# dir: outputs_confild_case2/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
dir: ./outputs_confild_case2
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: infer # running mode: infer
seed: 2025
output_dir: ${hydra:run.dir}
log_freq: 20

TRAIN:
batch_size: 40
test_batch_size: 40
epochs: 44500
mutil_GPU: 1
lr:
cnf: 1.e-4
latents: 1.e-5

EVAL:
confild_pretrained_model_path: ./outputs_confild_case2/confild_case2/epoch_99999
latent_pretrained_model_path: ./outputs_confild_case2/latent_case2/epoch_99999

CONFILD:
input_keys: ["confild_x", "latent_z"]
output_keys: ["confild_output"]
num_hidden_layers: 10
out_features: 4
hidden_features: 256
in_coord_features: 2
in_latent_features: 256

Latent:
input_keys: ["latent_x"]
output_keys: ["latent_z"]
N_samples: 1200
lumped: False
N_features: 256
dims: 2

INFER:
Latent:
INFER:
pretrained_model_path: null
export_path: ./inference/latent_case2
pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
onnx_path: ${INFER.Latent.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
log_freq: 20
Confild:
INFER:
pretrained_model_path: null
export_path: ./inference/confild_case2
pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
onnx_path: ${INFER.Confild.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
coord_shape: [400, 100, 2]
latents_shape: [1, 1, 256]
log_freq: 20
batch_size: 40

Data:
data_path: ../case2/data.npy
coor_path: null
normalizer:
method: "-11"
dim: 0
load_data_fn: load_channel_flow
108 changes: 108 additions & 0 deletions examples/confild/conf/confild_case3.yaml
@@ -0,0 +1,108 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

hydra:
run:
# dynamic output directory according to running time and override name
# dir: outputs_confild_case3/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
dir: ./outputs_confild_case3
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: infer # running mode: infer
seed: 2025
output_dir: ${hydra:run.dir}
log_freq: 20

TRAIN:
batch_size: 100
test_batch_size: 100
epochs: 4800
mutil_GPU: 2
lr:
cnf: 1.e-4
latents: 1.e-5

EVAL:
confild_pretrained_model_path: ./outputs_confild_case3/confild_case3/epoch_99999
latent_pretrained_model_path: ./outputs_confild_case3/latent_case3/epoch_99999

CONFILD:
input_keys: ["confild_x", "latent_z"]
output_keys: ["confild_output"]
num_hidden_layers: 117
out_features: 2
hidden_features: 256
in_coord_features: 2
in_latent_features: 256

Latent:
input_keys: ["latent_x"]
output_keys: ["latent_z"]
N_samples: 2880
lumped: True
N_features: 256
dims: 2

INFER:
Latent:
INFER:
pretrained_model_path: null
export_path: ./inference/latent_case3
pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
onnx_path: ${INFER.Latent.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
log_freq: 20
Confild:
INFER:
pretrained_model_path: null
export_path: ./inference/confild_case3
pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
onnx_path: ${INFER.Confild.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
coord_shape: [10884, 2]
latents_shape: [1, 256]
log_freq: 20
batch_size: 100

Data:
data_path: ../case3/data.npy
coor_path: ../case3/coor.npy
normalizer:
method: "-11"
dim: 0
load_data_fn: load_periodic_hill_flow
108 changes: 108 additions & 0 deletions examples/confild/conf/confild_case4.yaml
@@ -0,0 +1,108 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

hydra:
run:
# dynamic output directory according to running time and override name
# dir: outputs_confild_case4/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
dir: ./outputs_confild_case4
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: infer # running mode: infer
seed: 2025
output_dir: ${hydra:run.dir}
log_freq: 20

TRAIN:
batch_size: 4
test_batch_size: 4
epochs: 20000
mutil_GPU: 2
lr:
cnf: 1.e-4
latents: 1.e-5

EVAL:
confild_pretrained_model_path: ./outputs_confild_case4/confild_case4/epoch_99999
latent_pretrained_model_path: ./outputs_confild_case4/latent_case4/epoch_99999

CONFILD:
input_keys: ["confild_x", "latent_z"]
output_keys: ["confild_output"]
num_hidden_layers: 15
out_features: 3
hidden_features: 384
in_coord_features: 3
in_latent_features: 384

Latent:
input_keys: ["latent_x"]
output_keys: ["latent_z"]
N_samples: 1200
lumped: True
N_features: 384
dims: 3

INFER:
Latent:
INFER:
pretrained_model_path: null
export_path: ./inference/latent_case4
pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
onnx_path: ${INFER.Latent.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
log_freq: 20
Confild:
INFER:
pretrained_model_path: null
export_path: ./inference/confild_case4
pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
onnx_path: ${INFER.Confild.INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
coord_shape: [58483, 3]
latents_shape: [1, 384]
log_freq: 20
batch_size: 4

Data:
data_path: ../case4/data.npy
coor_path: ../case4/coor.npy
normalizer:
method: "-11"
dim: 0
load_data_fn: load_3d_flow
562 changes: 562 additions & 0 deletions examples/confild/confild.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ppsci/arch/__init__.py
@@ -22,6 +22,7 @@
from ppsci.arch.base import Arch # isort:skip
from ppsci.arch.cfdgcn import CFDGCN # isort:skip
from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip
from ppsci.arch.confild import LatentContainer, SIRENAutodecoder_film # isort:skip
from ppsci.arch.crystalgraphconvnet import CrystalGraphConvNet # isort:skip
from ppsci.arch.cuboid_transformer import CuboidTransformer # isort:skip
from ppsci.arch.cvit import CVit # isort:skip
@@ -88,6 +89,7 @@
"Generator",
"GraphCastNet",
"HEDeepONets",
"LatentContainer",
"LorenzEmbedding",
"LNO",
"MLP",
@@ -100,6 +102,7 @@
"PrecipNet",
"RosslerEmbedding",
"SFNONet",
"SIRENAutodecoder_film",
"SPINN",
"TFNO1dNet",
"TFNO2dNet",
380 changes: 380 additions & 0 deletions ppsci/arch/confild.py
@@ -0,0 +1,380 @@
import math
from collections import OrderedDict

import numpy as np
import paddle

DEFAULT_W0 = 30.0


class Swish(paddle.nn.Layer):
[Review comment · Contributor] Please check whether the swish in ppsci/arch/activation.py can be reused here.

def __init__(self):
super().__init__()
self.Sigmoid = paddle.nn.Sigmoid()

def forward(self, x):
return x * self.Sigmoid(x)


class Sine(paddle.nn.Layer):
    def __init__(self, w0=DEFAULT_W0):
        super().__init__()  # initialize the Layer before assigning attributes
        self.w0 = w0

def forward(self, input):
return paddle.sin(x=self.w0 * input)


def sine_init(m, w0=DEFAULT_W0):
with paddle.no_grad():
if hasattr(m, "weight"):
num_input = m.weight.shape[-1]
m.weight.uniform_(
min=-math.sqrt(6 / num_input) / w0, max=math.sqrt(6 / num_input) / w0
)


def first_layer_sine_init(m):
with paddle.no_grad():
if hasattr(m, "weight"):
num_input = m.weight.shape[-1]
m.weight.uniform_(min=-1 / num_input, max=1 / num_input)


def __check_Linear_weight(m):
    # True only for Linear layers that actually expose a weight attribute
    return isinstance(m, paddle.nn.Linear) and hasattr(m, "weight")


def init_weights_normal(m):
if __check_Linear_weight(m):
init_KaimingNormal = paddle.nn.initializer.KaimingNormal(
nonlinearity="relu", negative_slope=0.0
)
init_KaimingNormal(m.weight)


def init_weights_selu(m):
if __check_Linear_weight(m):
num_input = m.weight.shape[-1]
init_Normal = paddle.nn.initializer.Normal(std=1 / math.sqrt(num_input))
init_Normal(m.weight)


def init_weights_elu(m):
if __check_Linear_weight(m):
num_input = m.weight.shape[-1]
init_Normal = paddle.nn.initializer.Normal(
std=math.sqrt(1.5505188080679277) / math.sqrt(num_input)
)
init_Normal(m.weight)


def init_weights_xavier(m):
if __check_Linear_weight(m):
init_XavierNormal = paddle.nn.initializer.XavierNormal()
init_XavierNormal(m.weight)


NLS_AND_INITS = {
"sine": (Sine(), sine_init, first_layer_sine_init),
    [Review comment · Contributor] Same as above.

"relu": (paddle.nn.ReLU(), init_weights_normal, None),
"sigmoid": (paddle.nn.Sigmoid(), init_weights_xavier, None),
"tanh": (paddle.nn.Tanh(), init_weights_xavier, None),
"selu": (paddle.nn.SELU(), init_weights_selu, None),
"softplus": (paddle.nn.Softplus(), init_weights_normal, None),
"elu": (paddle.nn.ELU(), init_weights_elu, None),
"swish": (Swish(), init_weights_xavier, None),
}


class BatchLinear(paddle.nn.Linear):
"""
This is a linear transformation implemented manually. It also allows maually input parameters.
for initialization, (in_features, out_features) needs to be provided.
weight is of shape (out_features*in_features)
bias is of shape (out_features)
"""


def forward(self, input, params=None):
if params is None:
params = OrderedDict(self.named_parameters())
bias = params.get("bias", None)
weight = params["weight"]

output = paddle.matmul(x=input, y=weight)
if bias is not None:
output += bias.unsqueeze(axis=-2)
return output
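A small usage sketch for the manual-parameter path above (illustrative, not from the PR). weight follows Paddle's (in_features, out_features) layout, which the transpose-free matmul in forward relies on.

import paddle
from collections import OrderedDict

layer = BatchLinear(4, 8)
x = paddle.randn([2, 16, 4])
y_default = layer(x)  # uses the layer's own weight and bias
custom = OrderedDict(
    weight=paddle.randn([4, 8]),  # (in_features, out_features)
    bias=paddle.zeros([8]),
)
y_custom = layer(x, params=custom)  # same computation with injected parameters
assert y_custom.shape == [2, 16, 8]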


class FeatureMapping:
"""
This is feature mapping class for fourier feature networks
"""

def __init__(
self,
in_features,
mode="basic",
gaussian_mapping_size=256,
gaussian_rand_key=0,
gaussian_tau=1.0,
pe_num_freqs=4,
pe_scale=2,
pe_init_scale=1,
pe_use_nyquist=True,
pe_lowest_dim=None,
rbf_out_features=None,
rbf_range=1.0,
rbf_std=0.5,
):
"""
inputs:
in_freatures: number of input features
mapping_size: output features for Gaussian mapping
rand_key: random key for Gaussian mapping
tau: standard deviation for Gaussian mapping
num_freqs: number of frequencies for P.E.
scale = 2: base scale of frequencies for P.E.
init_scale: initial scale for P.E.
use_nyquist: use nyquist to calculate num_freqs or not.
"""
        self.mode = mode
        if mode == "basic":
            self.B = np.eye(in_features)
            self.dim = self.B.shape[0] * 2  # sin/cos doubles the width
        elif mode == "gaussian":
            rng = np.random.default_rng(gaussian_rand_key)
            self.B = rng.normal(
                loc=0.0, scale=gaussian_tau, size=(gaussian_mapping_size, in_features)
            )
            self.dim = self.B.shape[0] * 2
        elif mode == "positional":
            # pe_use_nyquist is a bool, so test it directly
            if pe_use_nyquist and pe_lowest_dim:
                pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim)
            self.B = pe_init_scale * np.vstack(
                [(pe_scale**i * np.eye(in_features)) for i in range(pe_num_freqs)]
            )
            self.dim = self.B.shape[0] * 2
        elif mode == "rbf":
            self.centers = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.empty(
                    shape=(rbf_out_features, in_features), dtype="float32"
                )
            )
            self.sigmas = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.empty(shape=(rbf_out_features,), dtype="float32")
            )
            init_Uniform = paddle.nn.initializer.Uniform(
                low=-1 * rbf_range, high=rbf_range
            )
            init_Uniform(self.centers)
            init_Constant = paddle.nn.initializer.Constant(value=rbf_std)
            init_Constant(self.sigmas)
            self.dim = rbf_out_features  # RBF maps straight to rbf_out_features

def __call__(self, input):
if self.mode in ["basic", "gaussian", "positional"]:
return self.fourier_mapping(input, self.B)
elif self.mode == "rbf":
return self.rbf_mapping(input)

def get_num_frequencies_nyquist(self, samples):
nyquist_rate = 1 / (2 * (2 * 1 / samples))
return int(math.floor(math.log(nyquist_rate, 2)))

@staticmethod
def fourier_mapping(x, B):
"""
x is the input, B is the reference information
"""
if B is None:
return x
else:
B = paddle.to_tensor(data=B, dtype="float32", place=x.place)
x_proj = 2.0 * np.pi * x @ B.T
return paddle.concat(
x=[paddle.sin(x=x_proj), paddle.cos(x=x_proj)], axis=-1
)

def rbf_mapping(self, x):
size = tuple(x.shape)[:-1] + tuple(self.centers.shape)
x = x.unsqueeze(axis=-2).expand(shape=size)
distances = (x - self.centers).pow(y=2).sum(axis=-1) * self.sigmas
return self.gaussian(distances)

@staticmethod
def gaussian(alpha):
phi = paddle.exp(x=-1 * alpha.pow(y=2))
return phi
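A brief standalone sketch of the Gaussian mode (illustrative): coordinates are projected through the fixed random matrix B and returned as concatenated sin/cos features, so 2 coordinate dimensions become 2 * gaussian_mapping_size mapped features. The 918-point shape is borrowed from confild_case1.yaml for flavor.

import paddle

fm = FeatureMapping(in_features=2, mode="gaussian", gaussian_mapping_size=64)
coords = paddle.rand([918, 2])    # e.g. case1's 918 mesh points
feats = fm(coords)                # sin/cos of 2*pi * coords @ B.T
assert feats.shape == [918, 128]  # 2 * gaussian_mapping_size
assert feats.shape[-1] == fm.dim  # fm.dim records the mapped width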


class SIRENAutodecoder_film(paddle.nn.Layer):
"""
    SIREN network with auto-decoding: coordinates and a per-sample latent code are
    decoded into field values, with the latent injected at every layer (FiLM-style).
Args:
input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict.
output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict.
in_coord_features (int, optional): Number of input coordinates features
in_latent_features (int, optional): Number of input latent features
out_features (int, optional): Number of output features
num_hidden_layers (int, optional): Number of hidden layers
hidden_features (int, optional): Number of hidden features
outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False.
nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine".
weight_init (Callable, optional): Weight initialization function. Defaults to None.
bias_init (Callable, optional): Bias initialization function. Defaults to None.
premap_mode (str, optional): Feature mapping mode. Defaults to None.
Examples:
        >>> model = ppsci.arch.SIRENAutodecoder_film(
        ...     input_keys=["input1", "input2"],
        ...     output_keys=("output",),
        ...     in_coord_features=2,
        ...     in_latent_features=128,
        ...     out_features=3,
        ...     num_hidden_layers=10,
        ...     hidden_features=128,
        ... )
        >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])}
        >>> out_dict = model(input_data)
        >>> for k, v in out_dict.items():
        ...     print(k, v.shape)
        output [10, 3]
"""

def __init__(
self,
input_keys,
output_keys,
in_coord_features,
in_latent_features,
out_features,
num_hidden_layers,
hidden_features,
outermost_linear=False,
nonlinearity="sine",
weight_init=None,
bias_init=None,
premap_mode=None,
**kwargs,
):
super().__init__()
self.input_keys = input_keys
self.output_keys = output_keys

self.premap_mode = premap_mode
if self.premap_mode is not None:
self.premap_layer = FeatureMapping(
in_coord_features, mode=premap_mode, **kwargs
)
in_coord_features = self.premap_layer.dim
self.first_layer_init = None
self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity]
if weight_init is not None:
self.weight_init = weight_init
else:
self.weight_init = nl_weight_init
self.net1 = paddle.nn.LayerList(
sublayers=[BatchLinear(in_coord_features, hidden_features)]
+ [
BatchLinear(hidden_features, hidden_features)
for i in range(num_hidden_layers)
]
+ [BatchLinear(hidden_features, out_features)]
)
self.net2 = paddle.nn.LayerList(
sublayers=[
BatchLinear(in_latent_features, hidden_features, bias_attr=False)
for i in range(num_hidden_layers + 1)
]
)
if self.weight_init is not None:
self.net1.apply(self.weight_init)
self.net2.apply(self.weight_init)
if first_layer_init is not None:
self.net1[0].apply(first_layer_init)
self.net2[0].apply(first_layer_init)
if bias_init is not None:
self.net2.apply(bias_init)

def forward(self, input_data):
coords = input_data[self.input_keys[0]]
latents = input_data[self.input_keys[1]]
if self.premap_mode is not None:
x = self.premap_layer(coords)
else:
x = coords

for i in range(len(self.net1) - 1):
x = self.net1[i](x) + self.net2[i](latents)
x = self.nl(x)
x = self.net1[-1](x)
return {self.output_keys[0]: x}

def disable_gradient(self):
        [Review comment · Contributor] Please check whether this method is actually used; if it is, can it be replaced with no_grad?

        for param in self.parameters():
            param.stop_gradient = True  # freeze every parameter of the decoder


class LatentContainer(paddle.nn.Layer):
"""
a model container that stores latents for multi GPU
Args:
        input_keys (Tuple[str, ...], optional): Key to get the input tensor from the dict. Defaults to ("input",).
        output_keys (Tuple[str, ...], optional): Key to save the output tensor into the dict. Defaults to ("output",).
N_samples (int, optional): Number of samples. Defaults to None.
N_features (int, optional): Number of features. Defaults to None.
dims (int, optional): Number of dimensions. Defaults to None.
lumped (bool, optional): Whether to lump the latents. Defaults to False.
Examples:
>>> model = ppsci.arch.LatentContainer(N_samples=1600, N_features=128, dims=2, lumped=True)
        >>> input_data = paddle.arange(0, 1600, dtype="int64")
>>> input_dict = {"input": input_data}
>>> out_dict = model(input_dict)
>>> for k, v in out_dict.items():
... print(k, v.shape)
output [1600, 1, 128]
"""

def __init__(
self,
input_keys=("input",),
output_keys=("output",),
N_samples=None,
N_features=None,
dims=None,
lumped=False,
):
super().__init__()
self.input_keys = input_keys
self.output_keys = output_keys
self.dims = [1] * dims if not lumped else [1]
self.expand_dims = " ".join(["1" for _ in range(dims)]) if not lumped else "1"
self.expand_dims = f"N f -> N {self.expand_dims} f"
self.latents = self.create_parameter(
shape=(N_samples, N_features),
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0.0),
)

def forward(self, batch_ids):
x = batch_ids[self.input_keys[0]]
selected_latents = paddle.gather(self.latents, x)
if len(selected_latents.shape) > 1:
getShape = (
[tuple(selected_latents.shape)[0]]
+ self.dims
+ [tuple(selected_latents.shape)[1]]
)
else:
getShape = [-1] + self.dims
expanded_latents = selected_latents.reshape(getShape)
return {self.output_keys[0]: expanded_latents}
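Putting the two layers together, a minimal end-to-end sketch under assumed case1-like sizes (918 mesh points, 128 latent features). The key names match the confild_case1.yaml sections above; the batch of snapshot ids and the random coordinates are made up for illustration.

import paddle

latent = LatentContainer(
    input_keys=("latent_x",), output_keys=("latent_z",),
    N_samples=16000, N_features=128, dims=2, lumped=True,
)
decoder = SIRENAutodecoder_film(
    input_keys=("confild_x", "latent_z"), output_keys=("confild_output",),
    in_coord_features=2, in_latent_features=128,
    out_features=3, num_hidden_layers=10, hidden_features=128,
)

snapshot_ids = paddle.arange(0, 4, dtype="int64")   # a batch of 4 snapshot indices
z = latent({"latent_x": snapshot_ids})["latent_z"]  # -> [4, 1, 128]
coords = paddle.rand([4, 918, 2])                   # coordinates for each snapshot
out = decoder({"confild_x": coords, "latent_z": z})
print(out["confild_output"].shape)                  # [4, 918, 3]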