diff --git a/docs/zh/api/loss/loss.md b/docs/zh/api/loss/loss.md
index f5d24edc39..2d533b1f27 100644
--- a/docs/zh/api/loss/loss.md
+++ b/docs/zh/api/loss/loss.md
@@ -11,6 +11,7 @@
       - L2RelLoss
       - MAELoss
       - MSELoss
+      - CausalMSELoss
       - MSELossWithL2Decay
       - IntegralLoss
      - PeriodicL1Loss
diff --git a/examples/allen_cahn/allen_cahn_causal.py b/examples/allen_cahn/allen_cahn_causal.py
new file mode 100644
index 0000000000..89840177a4
--- /dev/null
+++ b/examples/allen_cahn/allen_cahn_causal.py
@@ -0,0 +1,303 @@
+"""
+Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
+"""
+
+from os import path as osp
+
+import hydra
+import numpy as np
+import paddle
+import scipy.io as sio
+from matplotlib import pyplot as plt
+from omegaconf import DictConfig
+
+import ppsci
+from ppsci.utils import misc
+
+dtype = paddle.get_default_dtype()
+
+
+def plot(
+    t_star: np.ndarray,
+    x_star: np.ndarray,
+    u_ref: np.ndarray,
+    u_pred: np.ndarray,
+    output_dir: str,
+):
+    fig = plt.figure(figsize=(18, 5))
+    TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
+    u_ref = u_ref.reshape([len(t_star), len(x_star)])
+
+    plt.subplot(1, 3, 1)
+    plt.pcolor(TT, XX, u_ref, cmap="jet")
+    plt.colorbar()
+    plt.xlabel("t")
+    plt.ylabel("x")
+    plt.title("Exact")
+    plt.tight_layout()
+
+    plt.subplot(1, 3, 2)
+    plt.pcolor(TT, XX, u_pred, cmap="jet")
+    plt.colorbar()
+    plt.xlabel("t")
+    plt.ylabel("x")
+    plt.title("Predicted")
+    plt.tight_layout()
+
+    plt.subplot(1, 3, 3)
+    plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
+    plt.colorbar()
+    plt.xlabel("t")
+    plt.ylabel("x")
+    plt.title("Absolute error")
+    plt.tight_layout()
+
+    fig_path = osp.join(output_dir, "ac.png")
+    print(f"Saving figure to {fig_path}")
+    fig.savefig(fig_path, bbox_inches="tight", dpi=400)
+    plt.close()
+
+
+def train(cfg: DictConfig):
+    # set model
+    model = ppsci.arch.MLP(**cfg.MODEL)
+
+    # set equation
+    equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}
+
+    data = sio.loadmat(cfg.DATA_PATH)
+    u_ref = data["usol"].astype(dtype)  # (nt, nx)
+    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
+    x_star = data["x"].flatten().astype(dtype)  # [nx, ]
+
+    u0 = u_ref[0, :]  # [nx, ]
+
+    t0 = t_star[0]  # float
+    t1 = t_star[-1]  # float
+
+    x0 = x_star[0]  # float
+    x1 = x_star[-1]  # float
+
+    # set constraint
+    def gen_input_batch():
+        tx = np.random.uniform(
+            [t0, x0],
+            [t1, x1],
+            (cfg.TRAIN.batch_size, 2),
+        ).astype(dtype)
+        return {
+            "t": np.sort(tx[:, 0:1], axis=0),
+            "x": tx[:, 1:2],
+        }
+
+    def gen_label_batch(input_batch):
+        return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}
+
+    pde_constraint = ppsci.constraint.SupervisedConstraint(
+        {
+            "dataset": {
+                "name": "ContinuousNamedArrayDataset",
+                "input": gen_input_batch,
+                "label": gen_label_batch,
+            },
+        },
+        output_expr=equation["AllenCahn"].equations,
+        loss=ppsci.loss.CausalMSELoss(
+            cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol
+        ),
+        name="PDE",
+    )
+
+    ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
+    ic_label = {"u": u0.reshape([-1, 1])}
+    ic = ppsci.constraint.SupervisedConstraint(
+        {
+            "dataset": {
+                "name": "IterableNamedArrayDataset",
+                "input": ic_input,
+                "label": ic_label,
+            },
+        },
+        output_expr={"u": lambda out: out["u"]},
+        loss=ppsci.loss.MSELoss("mean"),
+        name="IC",
+    )
+    # wrap constraints together
+    constraint = {
+        pde_constraint.name: pde_constraint,
+        ic.name: ic,
+    }
+
+    # set optimizer
+    lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
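+        # NOTE: the scheduler hyper-parameters (learning_rate, gamma, decay_steps,
+        # by_epoch) are unpacked verbatim from TRAIN.lr_scheduler in the YAML config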
+        **cfg.TRAIN.lr_scheduler
+    )()
+    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
+
+    # set validator
+    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
+    eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
+    eval_label = {"u": u_ref.reshape([-1, 1])}
+    u_validator = ppsci.validate.SupervisedValidator(
+        {
+            "dataset": {
+                "name": "NamedArrayDataset",
+                "input": eval_data,
+                "label": eval_label,
+            },
+            "batch_size": cfg.EVAL.batch_size,
+        },
+        ppsci.loss.MSELoss("mean"),
+        {"u": lambda out: out["u"]},
+        metric={"L2Rel": ppsci.metric.L2Rel()},
+        name="u_validator",
+    )
+    validator = {u_validator.name: u_validator}
+
+    # initialize solver
+    solver = ppsci.solver.Solver(
+        model,
+        constraint,
+        cfg.output_dir,
+        optimizer,
+        lr_scheduler,
+        cfg.TRAIN.epochs,
+        cfg.TRAIN.iters_per_epoch,
+        save_freq=cfg.TRAIN.save_freq,
+        log_freq=cfg.log_freq,
+        eval_during_train=True,
+        eval_freq=cfg.TRAIN.eval_freq,
+        seed=cfg.seed,
+        equation=equation,
+        validator=validator,
+        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
+        checkpoint_path=cfg.TRAIN.checkpoint_path,
+        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
+        use_tbd=True,
+        cfg=cfg,
+    )
+    # train model
+    solver.train()
+    # evaluate model after training
+    solver.eval()
+    # visualize prediction after training
+    u_pred = solver.predict(
+        eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
+    )["u"]
+    u_pred = u_pred.reshape([len(t_star), len(x_star)])
+
+    # plot
+    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
+
+
+def evaluate(cfg: DictConfig):
+    # set model
+    model = ppsci.arch.MLP(**cfg.MODEL)
+
+    data = sio.loadmat(cfg.DATA_PATH)
+    u_ref = data["usol"].astype(dtype)  # (nt, nx)
+    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
+    x_star = data["x"].flatten().astype(dtype)  # [nx, ]
+
+    # set validator
+    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
+    eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
+    eval_label = {"u": u_ref.reshape([-1, 1])}
+    u_validator = ppsci.validate.SupervisedValidator(
+        {
+            "dataset": {
+                "name": "NamedArrayDataset",
+                "input": eval_data,
+                "label": eval_label,
+            },
+            "batch_size": cfg.EVAL.batch_size,
+        },
+        ppsci.loss.MSELoss("mean"),
+        {"u": lambda out: out["u"]},
+        metric={"L2Rel": ppsci.metric.L2Rel()},
+        name="u_validator",
+    )
+    validator = {u_validator.name: u_validator}
+
+    # initialize solver
+    solver = ppsci.solver.Solver(
+        model,
+        output_dir=cfg.output_dir,
+        log_freq=cfg.log_freq,
+        validator=validator,
+        pretrained_model_path=cfg.EVAL.pretrained_model_path,
+        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
+    )
+
+    # evaluate pretrained model
+    solver.eval()
+    # visualize prediction
+    u_pred = solver.predict(
+        eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
+    )["u"]
+    u_pred = u_pred.reshape([len(t_star), len(x_star)])
+
+    # plot
+    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
+
+
+def export(cfg: DictConfig):
+    # set model
+    model = ppsci.arch.MLP(**cfg.MODEL)
+
+    # initialize solver
+    solver = ppsci.solver.Solver(
+        model,
+        pretrained_model_path=cfg.INFER.pretrained_model_path,
+    )
+    # export model
+    from paddle.static import InputSpec
+
+    input_spec = [
+        {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
+    ]
+    solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)
+
+
+def inference(cfg: DictConfig):
+    from deploy.python_infer import pinn_predictor
+
+    predictor = pinn_predictor.PINNPredictor(cfg)
+    data = sio.loadmat(cfg.DATA_PATH)
+    u_ref = data["usol"].astype(dtype)  # (nt, nx)
+    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
+    x_star = data["x"].flatten().astype(dtype)  # [nx, ]
+    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
+
+    input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
+    output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
+    # map inference output keys back to cfg.MODEL.output_keys
+    output_dict = {
+        store_key: output_dict[infer_key]
+        for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
+    }
+    u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
+
+    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
+
+
+@hydra.main(
+    version_base=None, config_path="./conf", config_name="allen_cahn_causal.yaml"
+)
+def main(cfg: DictConfig):
+    if cfg.mode == "train":
+        train(cfg)
+    elif cfg.mode == "eval":
+        evaluate(cfg)
+    elif cfg.mode == "export":
+        export(cfg)
+    elif cfg.mode == "infer":
+        inference(cfg)
+    else:
+        raise ValueError(
+            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
+        )
+
+
+if __name__ == "__main__":
+    main()
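Editor's note: `gen_input_batch` in the example above sorts the sampled `t` column before returning the batch. `CausalMSELoss` later reshapes the per-point residuals to `[n_chunks, -1]`, so this sort is what makes each chunk a contiguous time window. A minimal NumPy sketch of that invariant (illustrative values, not part of the patch):

```python
import numpy as np

batch_size, n_chunks = 4096, 32

# mirror gen_input_batch: uniform samples in t, sorted along the batch axis
t = np.sort(np.random.uniform(0.0, 1.0, (batch_size, 1)), axis=0)

# CausalMSELoss views the batch as n_chunks consecutive time windows
t_chunks = t.reshape(n_chunks, -1)
assert (t_chunks.max(axis=1)[:-1] <= t_chunks.min(axis=1)[1:]).all()
```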
diff --git a/examples/allen_cahn/allen_cahn_plain.py b/examples/allen_cahn/allen_cahn_plain.py
index f12cb13f49..e1cbcc10d3 100644
--- a/examples/allen_cahn/allen_cahn_plain.py
+++ b/examples/allen_cahn/allen_cahn_plain.py
@@ -1,5 +1,5 @@
 """
-Reference: https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html
+Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
 """

 from os import path as osp
@@ -53,6 +53,7 @@ def plot(
     plt.tight_layout()

     fig_path = osp.join(output_dir, "ac.png")
+    print(f"Saving figure to {fig_path}")
     fig.savefig(fig_path, bbox_inches="tight", dpi=400)
     plt.close()

@@ -101,7 +102,7 @@ def gen_label_batch(input_batch):
             },
         },
         output_expr=equation["AllenCahn"].equations,
-        loss=ppsci.loss.MSELoss(),
+        loss=ppsci.loss.MSELoss("mean"),
         name="PDE",
     )
diff --git a/examples/allen_cahn/conf/allen_cahn.yaml b/examples/allen_cahn/conf/allen_cahn.yaml
index 2facc7f966..b4d7004258 100644
--- a/examples/allen_cahn/conf/allen_cahn.yaml
+++ b/examples/allen_cahn/conf/allen_cahn.yaml
@@ -58,9 +58,6 @@ TRAIN:
   batch_size: 4096
   pretrained_model_path: null
   checkpoint_path: null
-  ema:
-    decay: 0.9
-    avg_freq: 1

 # evaluation settings
 EVAL:
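Editor's note: the causal settings in the new config below (`TRAIN.causal.n_chunks: 32` with `TRAIN.batch_size: 4096`) give 128 residual points per time chunk. Because `CausalMSELoss.forward` reshapes the residual vector to `[n_chunks, -1]`, the PDE batch size must be divisible by `n_chunks`. A quick sanity check (hypothetical helper, not part of the PR):

```python
def points_per_chunk(batch_size: int, n_chunks: int) -> int:
    """Return the number of residual points per causal time chunk."""
    if batch_size % n_chunks != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be divisible by n_chunks ({n_chunks})"
        )
    return batch_size // n_chunks

print(points_per_chunk(4096, 32))  # 128
```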
diff --git a/examples/allen_cahn/conf/allen_cahn_causal_fourier_rwf.yaml b/examples/allen_cahn/conf/allen_cahn_causal_fourier_rwf.yaml
new file mode 100644
index 0000000000..f664532b4f
--- /dev/null
+++ b/examples/allen_cahn/conf/allen_cahn_causal_fourier_rwf.yaml
@@ -0,0 +1,93 @@
+hydra:
+  run:
+    # dynamic output directory according to running time and override name
+    dir: outputs_allen_cahn_causal_fourier_rwf/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+  job:
+    name: ${mode} # name of logfile
+    chdir: false # keep current working directory unchanged
+    config:
+      override_dirname:
+        exclude_keys:
+          - TRAIN.checkpoint_path
+          - TRAIN.pretrained_model_path
+          - EVAL.pretrained_model_path
+          - INFER.pretrained_model_path
+          - mode
+          - output_dir
+          - log_freq
+  callbacks:
+    init_callback:
+      _target_: ppsci.utils.callbacks.InitCallback
+  sweep:
+    # output directory for multirun
+    dir: ${hydra.run.dir}
+    subdir: ./
+
+# general settings
+mode: train # running mode: train/eval/export/infer
+seed: 42
+output_dir: ${hydra:run.dir}
+log_freq: 100
+
+DATA_PATH: ./dataset/allen_cahn.mat
+
+# model settings
+MODEL:
+  input_keys: [t, x]
+  output_keys: [u]
+  num_layers: 4
+  hidden_size: 256
+  activation: tanh
+  periods:
+    x: [2.0, False]
+  fourier:
+    dim: 256
+    scale: 1.0
+  random_weight:
+    mean: 0.5
+    std: 0.1
+
+# training settings
+TRAIN:
+  epochs: 200
+  iters_per_epoch: 1000
+  save_freq: 10
+  eval_during_train: true
+  eval_freq: 1
+  lr_scheduler:
+    epochs: ${TRAIN.epochs}
+    iters_per_epoch: ${TRAIN.iters_per_epoch}
+    learning_rate: 1.0e-3
+    gamma: 0.9
+    decay_steps: 2000
+    by_epoch: false
+  batch_size: 4096
+  pretrained_model_path: null
+  checkpoint_path: null
+  causal:
+    n_chunks: 32
+    tol: 1.0
+
+# evaluation settings
+EVAL:
+  pretrained_model_path: null
+  eval_with_no_grad: true
+  batch_size: 4096
+
+# inference settings
+INFER:
+  pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_plain_pretrained.pdparams
+  export_path: ./inference/allen_cahn
+  pdmodel_path: ${INFER.export_path}.pdmodel
+  pdiparams_path: ${INFER.export_path}.pdiparams
+  onnx_path: ${INFER.export_path}.onnx
+  device: gpu
+  engine: native
+  precision: fp32
+  ir_optim: true
+  min_subgraph_size: 5
+  gpu_mem: 2000
+  gpu_id: 0
+  max_batch_size: 1024
+  num_cpu_threads: 10
+  batch_size: 1024
diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index 0816556b8e..102dee3f69 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -14,6 +14,7 @@

 from __future__ import annotations

+import math
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -53,6 +54,51 @@ def forward(self, input):
         return nn.functional.linear(input, weight, self.bias)


+class RandomWeightFactorization(nn.Layer):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        mean: float = 0.5,
+        std: float = 0.1,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight_v = self.create_parameter((in_features, out_features))
+        self.weight_g = self.create_parameter((out_features,))
+        if bias:
+            self.bias = self.create_parameter((out_features,))
+        else:
+            self.bias = None
+
+        self._init_weights(mean, std)
+
+    def _init_weights(self, mean, std):
+        with paddle.no_grad():
+            # glorot normal
+            fin, fout = self.weight_v.shape
+            var = 2.0 / (fin + fout)
+            stddev = math.sqrt(var) * 0.87962566103423978
+            initializer.trunc_normal_(self.weight_v)
+            paddle.assign(self.weight_v * stddev, self.weight_v)
+
+            nn.initializer.Normal(mean, std)(self.weight_g)
+            paddle.assign(paddle.exp(self.weight_g), self.weight_g)
+            paddle.assign(self.weight_v / self.weight_g, self.weight_v)
+            if self.bias is not None:
+                initializer.constant_(self.bias, 0.0)
+
+        self.weight_g.stop_gradient = False
+        self.weight_v.stop_gradient = False
+        if self.bias is not None:
+            self.bias.stop_gradient = False
+
+    def forward(self, input):
+        return nn.functional.linear(input, self.weight_g * self.weight_v, self.bias)
+
+
 class PeriodEmbedding(nn.Layer):
     def __init__(self, periods: Dict[str, Tuple[float, bool]]):
         super().__init__()
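Editor's note: `RandomWeightFactorization` re-parameterizes a dense layer's weight as W = diag(g) * V, with g = exp(s) and s ~ N(mean, std), so gradient descent can update the scale and direction of each output neuron separately; at initialization the factorized layer reproduces the underlying Glorot-normal weights exactly. A self-contained NumPy sketch of that identity (illustrative, not the PR's API):

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 3, 4

# Glorot-normal base weights, as in RandomWeightFactorization._init_weights
W = rng.normal(0.0, np.sqrt(2.0 / (in_features + out_features)), (in_features, out_features))

# per-output scale g = exp(s) with s ~ N(0.5, 0.1), and direction V = W / g
g = np.exp(rng.normal(0.5, 0.1, out_features))
V = W / g

# the factorized layer reproduces the original linear map at initialization
x = rng.normal(size=(5, in_features))
assert np.allclose(x @ (g * V), x @ W)
```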
@@ -60,7 +106,7 @@ def __init__(self, periods: Dict[str, Tuple[float, bool]]):
             k: self.create_parameter(
                 [],
                 attr=paddle.ParamAttr(trainable=trainable),
-                default_initializer=nn.initializer.Constant(2 * np.pi / p),
+                default_initializer=nn.initializer.Constant(2 * np.pi / (eval(p) if isinstance(p, str) else p)),
             )  # mu = 2*pi / period for sin/cos function
             for k, (p, trainable) in periods.items()
         }
@@ -75,6 +121,28 @@ def forward(self, x: Dict[str, paddle.Tensor]):
         return y


+class FourierEmbedding(nn.Layer):
+    def __init__(self, in_features, out_features, scale):
+        super().__init__()
+        if out_features % 2 != 0:
+            raise ValueError(f"out_features must be even, but got {out_features}.")
+
+        self.kernel = self.create_parameter(
+            [in_features, out_features // 2],
+            default_initializer=nn.initializer.Normal(std=scale),
+        )
+
+    def forward(self, x: paddle.Tensor):
+        y = paddle.concat(
+            [
+                paddle.cos(x @ self.kernel),
+                paddle.sin(x @ self.kernel),
+            ],
+            axis=-1,
+        )
+        return y
+
+
 class MLP(base.Arch):
     """Multi layer perceptron network.

@@ -92,6 +160,10 @@ class MLP(base.Arch):
         periods (Optional[Dict[int, Tuple[float, bool]]]): Period of each input key,
             input in given channel will be period embeded if specified, each tuple of
             periods list is [period, trainable]. Defaults to None.
+        fourier (Optional[Dict[str, Union[float, int]]]): Random Fourier feature embedding,
+            e.g. {'dim': 256, 'scale': 1.0}. Defaults to None.
+        random_weight (Optional[Dict[str, float]]): Mean and std of random weight
+            factorization layer, e.g. {"mean": 0.5, "std": 0.1}. Defaults to None.

     Examples:
         >>> import paddle
@@ -122,7 +194,9 @@ def __init__(
         weight_norm: bool = False,
         input_dim: Optional[int] = None,
         output_dim: Optional[int] = None,
-        periods: Dict[int, Tuple[float, bool]] = None,
+        periods: Optional[Dict[int, Tuple[float, bool]]] = None,
+        fourier: Optional[Dict[str, Union[float, int]]] = None,
+        random_weight: Optional[Dict[str, float]] = None,
     ):
         super().__init__()
         self.input_keys = input_keys
@@ -130,6 +204,7 @@ def __init__(
         self.linears = []
         self.acts = []
         self.periods = periods
+        self.fourier = fourier
         if periods:
             self.period_emb = PeriodEmbedding(periods)
@@ -156,12 +231,27 @@ def __init__(
             # if input_dim is not specified
             cur_size += len(periods)

-        for i, _size in enumerate(hidden_size):
-            self.linears.append(
-                WeightNormLinear(cur_size, _size)
-                if weight_norm
-                else nn.Linear(cur_size, _size)
+        if fourier:
+            self.fourier_emb = FourierEmbedding(
+                cur_size, fourier["dim"], fourier["scale"]
             )
+            cur_size = fourier["dim"]
+
+        for i, _size in enumerate(hidden_size):
+            if weight_norm:
+                self.linears.append(WeightNormLinear(cur_size, _size))
+            elif random_weight:
+                self.linears.append(
+                    RandomWeightFactorization(
+                        cur_size,
+                        _size,
+                        mean=random_weight["mean"],
+                        std=random_weight["std"],
+                    )
+                )
+            else:
+                self.linears.append(nn.Linear(cur_size, _size))
+
             # initialize activation function
             self.acts.append(
                 act_mod.get_activation(activation)
@@ -180,10 +270,18 @@ def __init__(
         self.linears = nn.LayerList(self.linears)
         self.acts = nn.LayerList(self.acts)

-        self.last_fc = nn.Linear(
-            cur_size,
-            len(self.output_keys) if output_dim is None else output_dim,
-        )
+        if random_weight:
+            self.last_fc = RandomWeightFactorization(
+                cur_size,
+                len(self.output_keys) if output_dim is None else output_dim,
+                mean=random_weight["mean"],
+                std=random_weight["std"],
+            )
+        else:
+            self.last_fc = nn.Linear(
+                cur_size,
+                len(self.output_keys) if output_dim is None else output_dim,
+            )

         self.skip_connection = skip_connection

@@ -212,6 +310,10 @@ def forward(self, x):
             x = self.period_emb(x)

         y = self.concat_to_tensor(x, self.input_keys, axis=-1)
+
+        if self.fourier:
+            y = self.fourier_emb(y)
+
         y = self.forward_tensor(y)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
diff --git a/ppsci/autodiff/ad.py b/ppsci/autodiff/ad.py
index 55453aa07a..1214e83314 100644
--- a/ppsci/autodiff/ad.py
+++ b/ppsci/autodiff/ad.py
@@ -303,12 +303,6 @@ def _clear(self):
     def clear():
         """Clear cached Jacobians and Hessians.

-        Args:
-            None.
-
-        Returns:
-            None.
-
         Examples:
             >>> import paddle
             >>> import ppsci
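Editor's note: `FourierEmbedding` above projects the concatenated inputs through a Gaussian matrix B (a trainable parameter initialized with standard deviation `scale`) and returns `[cos(xB), sin(xB)]`, which is why `out_features` must be even. A minimal NumPy sketch of the mapping (illustrative names and values):

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, dim, scale = 2, 256, 1.0

B = rng.normal(0.0, scale, (in_features, dim // 2))  # kernel, std = scale
x = rng.uniform(-1.0, 1.0, (8, in_features))

features = np.concatenate([np.cos(x @ B), np.sin(x @ B)], axis=-1)
assert features.shape == (8, dim)
```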
diff --git a/ppsci/loss/__init__.py b/ppsci/loss/__init__.py
index cade922dd8..26332da8be 100644
--- a/ppsci/loss/__init__.py
+++ b/ppsci/loss/__init__.py
@@ -24,6 +24,7 @@
 from ppsci.loss.l2 import L2RelLoss
 from ppsci.loss.l2 import PeriodicL2Loss
 from ppsci.loss.mae import MAELoss
+from ppsci.loss.mse import CausalMSELoss
 from ppsci.loss.mse import MSELoss
 from ppsci.loss.mse import MSELossWithL2Decay
 from ppsci.loss.mse import PeriodicMSELoss
@@ -38,6 +39,7 @@
     "L2RelLoss",
     "PeriodicL2Loss",
     "MAELoss",
+    "CausalMSELoss",
     "MSELoss",
     "MSELossWithL2Decay",
     "PeriodicMSELoss",
diff --git a/ppsci/loss/mse.py b/ppsci/loss/mse.py
index 03152a1c14..411e69a126 100644
--- a/ppsci/loss/mse.py
+++ b/ppsci/loss/mse.py
@@ -18,6 +18,7 @@
 from typing import Optional
 from typing import Union

+import paddle
 import paddle.nn.functional as F
 from typing_extensions import Literal

@@ -99,6 +100,85 @@ def forward(self, output_dict, label_dict, weight_dict=None):
         return losses


+class CausalMSELoss(base.Loss):
+    r"""Class for causally weighted mean squared error loss.
+
+    $$
+    L = \frac{1}{M} \displaystyle\sum_{i=1}^M{w_i} \mathcal{L}_r^i,
+    $$
+
+    where $w_1=1$ and $w_i=\exp (-\epsilon \displaystyle\sum_{k=1}^{i-1} \mathcal{L}_r^k), i=2,3, \ldots, M.$
+
+    Args:
+        n_chunks (int): Number of chunks the (time-sorted) batch is split into.
+        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
+        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
+        tol (float, optional): Causal tolerance, i.e. $\epsilon$ in the paper. Defaults to 1.0.
+
+    Examples:
+        >>> import paddle
+        >>> from ppsci.loss import CausalMSELoss
+
+        >>> output_dict = {'u': paddle.to_tensor([[0.5, 0.9, 1.0], [1.1, -1.3, 0.0]])}
+        >>> label_dict = {'u': paddle.to_tensor([[-1.8, 1.0, -0.1], [-0.2, 2.5, 2.0]])}
+        >>> loss = CausalMSELoss(n_chunks=3)
+        >>> result = loss(output_dict, label_dict)
+        >>> print(result)
+        Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+               0.96841478)
+    """
+
+    def __init__(
+        self,
+        n_chunks: int,
+        reduction: Literal["mean", "sum"] = "mean",
+        weight: Optional[Union[float, Dict[str, float]]] = None,
+        tol: float = 1.0,
+    ):
+        if n_chunks <= 0:
+            raise ValueError(f"n_chunks should be positive, but got {n_chunks}")
+        if reduction not in ["mean", "sum"]:
+            raise ValueError(
+                f"reduction should be 'mean' or 'sum', but got {reduction}"
+            )
+        super().__init__(reduction, weight)
+        self.n_chunks = n_chunks
+        self.tol = tol
+        self.register_buffer(
+            "acc_mat", paddle.tril(paddle.ones([n_chunks, n_chunks]), -1)
+        )
+
+    def forward(self, output_dict, label_dict, weight_dict=None):
+        losses = 0.0
+        for key in label_dict:
+            loss = F.mse_loss(output_dict[key], label_dict[key], "none")
+            if weight_dict and key in weight_dict:
+                loss *= weight_dict[key]
+
+            if "area" in output_dict:
+                loss *= output_dict["area"]
+
+            # causal weighting
+            loss_t = loss.reshape([self.n_chunks, -1])  # [nt, nx]
+            weight_t = paddle.exp(
+                -self.tol * (self.acc_mat @ loss_t.mean(-1, keepdim=True))
+            )  # [nt, nt] x [nt, 1] ==> [nt, 1]
+            assert weight_t.shape[0] == self.n_chunks
+            loss = loss_t * weight_t.detach()
+
+            if self.reduction == "sum":
+                loss = loss.sum()
+            elif self.reduction == "mean":
+                loss = loss.mean()
+            if isinstance(self.weight, (float, int)):
+                loss *= self.weight
+            elif isinstance(self.weight, dict) and key in self.weight:
+                loss *= self.weight[key]
+
+            losses += loss
+        return losses
+
+
 class MSELossWithL2Decay(MSELoss):
     r"""MSELoss with L2 decay.
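Editor's note: the causal weights in `CausalMSELoss.forward` come from a strictly lower-triangular matrix that accumulates the mean loss of all earlier chunks, so chunk i is down-weighted until chunks 1..i-1 are fitted, and the first chunk always has weight 1. A standalone NumPy sketch of the weighting (illustrative chunk losses):

```python
import numpy as np

tol, n_chunks = 1.0, 4
chunk_losses = np.array([2.0, 1.0, 0.5, 0.25])  # mean residual per time chunk

acc_mat = np.tril(np.ones((n_chunks, n_chunks)), -1)  # strictly lower triangular
weights = np.exp(-tol * acc_mat @ chunk_losses)

print(weights)  # [1.0, exp(-2.0), exp(-3.0), exp(-3.5)]; w_1 is always 1
```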
diff --git a/ppsci/utils/ema.py b/ppsci/utils/ema.py
index 7db2ba322a..2664b2184b 100644
--- a/ppsci/utils/ema.py
+++ b/ppsci/utils/ema.py
@@ -154,10 +154,10 @@ class StochasticWeightAverage(AveragedModel):
     All parameters are updated by the formula as below:

     $$
-    \mathbf{\theta}_{EMA}^{t+1} = \alpha \mathbf{\theta}_{EMA}^{t} + (1 - \alpha) \mathbf{\theta}^{t}
+    \mathbf{\theta}_{SWA}^{t} = \frac{1}{t-t_0+1}\sum_{i=t_0}^t{\mathbf{\theta}^{i}}
     $$

-    Where $\theta_{EMA}^{t}$ is the moving average parameters and $\theta^{t}$ is the online parameters at step $t$.
+    Where $\theta_{SWA}^{t}$ is the average of the parameters from step $t_0$ to step $t$, and $\theta^{i}$ are the online parameters at step $i$.

     Args:
         model (nn.Layer): The model to be averaged.
@@ -165,6 +165,7 @@ class StochasticWeightAverage(AveragedModel):

     def __init__(self, model: nn.Layer):
         super().__init__(model, None)
+        self.n_avg += 1  # start at 1: the running average already contains the initial online weights

     def _update_fn_(self, shadow_param, model_param, step):
         dynamic_decay = step / (step + 1)
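Editor's note: taken together, the new `MODEL` options in the YAML map one-to-one onto the extended `ppsci.arch.MLP` constructor. A minimal instantiation sketch, assuming a PaddleScience environment with this patch applied:

```python
import paddle
import ppsci

model = ppsci.arch.MLP(
    input_keys=("t", "x"),
    output_keys=("u",),
    num_layers=4,
    hidden_size=256,
    activation="tanh",
    periods={"x": (2.0, False)},  # PeriodEmbedding: fixed period 2 in x
    fourier={"dim": 256, "scale": 1.0},  # FourierEmbedding after input concat
    random_weight={"mean": 0.5, "std": 0.1},  # RWF layers instead of nn.Linear
)
out = model({"t": paddle.rand([8, 1]), "x": paddle.rand([8, 1])})
print(out["u"].shape)  # [8, 1]
```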